diff --git a/pom.xml b/pom.xml
index ec79c41..ce248f1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -240,6 +240,11 @@
org.springframework.boot
spring-boot-starter-oauth2-resource-server
+
+ com.github.seancfoley
+ ipaddress
+ 5.5.1
+
diff --git a/src/main/java/gov/cdc/izgateway/logging/LoggingValve.java b/src/main/java/gov/cdc/izgateway/logging/LoggingValve.java
index 4a459b3..8547038 100644
--- a/src/main/java/gov/cdc/izgateway/logging/LoggingValve.java
+++ b/src/main/java/gov/cdc/izgateway/logging/LoggingValve.java
@@ -44,37 +44,11 @@ public class LoggingValve extends LoggingValveBase implements EventCreator {
Collections.unmodifiableList(Arrays.asList(EVENT_ID, SESSION_ID, METHOD, IP_ADDRESS, REQUEST_URI, COMMON_NAME));
private static final String REST_ADS = "/rest/ads";
- // Keep mappings for at most one minute.
- private static final int MAX_AGE = 60 * 1000;
- @SuppressWarnings("unused")
+ @SuppressWarnings("unused")
private ScheduledFuture> adsMonitor =
Executors.newSingleThreadScheduledExecutor(r -> new Thread(r, "ADS Monitor"))
.scheduleAtFixedRate(this::monitorADSRequests, 0, 15, TimeUnit.SECONDS);
private static final ConcurrentHashMap adsRequests = new ConcurrentHashMap<>();
- private Map map = new LinkedHashMap<>();
-
- private static class LoggingValveEvent implements Event {
- private final String id;
- private Date date;
- private int refs;
-
- private LoggingValveEvent(String id, Date date) {
- refs = 1;
- this.id = id;
- this.date = date;
- }
-
- @Override
- public String getId() {
- return id;
- }
-
- @Override
- public Date getDate() {
- return date;
- }
- }
-
@Autowired
public LoggingValve(PrincipalService principalService) {
this.principalService = principalService;
diff --git a/src/main/java/gov/cdc/izgateway/logging/event/TransactionData.java b/src/main/java/gov/cdc/izgateway/logging/event/TransactionData.java
index cc001ee..a0cdcea 100644
--- a/src/main/java/gov/cdc/izgateway/logging/event/TransactionData.java
+++ b/src/main/java/gov/cdc/izgateway/logging/event/TransactionData.java
@@ -500,8 +500,6 @@ public void setResponseEchoBack(String val) {
setResponsePayloadSize(StringUtils.length(val));
}
- private static final Pattern TEST_MESSAGE_PATTERN
- = Pattern.compile("^(([A-Z]+(AIRA|TEST)\\^[A-Z]+(AIRA|TEST))|([A-Z]*IZG[A-Z]*))\\^", Pattern.CASE_INSENSITIVE);
/**
* Indicates if the message matches known test patterns
* @param message The message
diff --git a/src/main/java/gov/cdc/izgateway/logging/info/EndPointInfo.java b/src/main/java/gov/cdc/izgateway/logging/info/EndPointInfo.java
index ac2dfd4..8adb385 100644
--- a/src/main/java/gov/cdc/izgateway/logging/info/EndPointInfo.java
+++ b/src/main/java/gov/cdc/izgateway/logging/info/EndPointInfo.java
@@ -1,13 +1,7 @@
package gov.cdc.izgateway.logging.info;
import java.io.Serializable;
-import java.security.cert.X509Certificate;
import java.util.Date;
-import java.util.Map;
-
-
-import gov.cdc.izgateway.security.IzgPrincipal;
-import gov.cdc.izgateway.utils.X500Utils;
import org.apache.commons.lang3.StringUtils;
import com.fasterxml.jackson.annotation.JsonFormat;
import com.fasterxml.jackson.annotation.JsonFormat.Shape;
@@ -20,8 +14,6 @@
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
-import javax.security.auth.x500.X500Principal;
-
/**
* An endpoint describes the inbound or outbound connection to
* IZ Gateway during a transaction. It is abstract to ensure
diff --git a/src/main/java/gov/cdc/izgateway/logging/info/SourceInfo.java b/src/main/java/gov/cdc/izgateway/logging/info/SourceInfo.java
index 02aacd3..2c2ddd8 100644
--- a/src/main/java/gov/cdc/izgateway/logging/info/SourceInfo.java
+++ b/src/main/java/gov/cdc/izgateway/logging/info/SourceInfo.java
@@ -7,7 +7,6 @@
import gov.cdc.izgateway.security.IzgPrincipal;
import gov.cdc.izgateway.security.principal.CertificatePrincipalProviderImpl;
-import gov.cdc.izgateway.utils.X500Utils;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
@@ -68,7 +67,7 @@ public void setPrincipal(IzgPrincipal principal) {
* @param certificate
*/
public void setCertificate(X509Certificate certificate) {
- setPrincipal(CertificatePrincipalProviderImpl.createPrincipalFromCertificate(certificate));
+ setPrincipal(CertificatePrincipalProviderImpl.createPrincipalFromCertificate(certificate));
}
}
diff --git a/src/main/java/gov/cdc/izgateway/model/ICertificateStatus.java b/src/main/java/gov/cdc/izgateway/model/ICertificateStatus.java
index 20a2eee..18e6894 100644
--- a/src/main/java/gov/cdc/izgateway/model/ICertificateStatus.java
+++ b/src/main/java/gov/cdc/izgateway/model/ICertificateStatus.java
@@ -4,7 +4,6 @@
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
-import java.sql.Timestamp;
import java.util.Date;
import java.util.ServiceConfigurationError;
@@ -51,7 +50,6 @@ static MessageDigest getMessageDigest() {
*
* @param cert The certificate to compute the thumbprint.
* @return A string representing the thumbprint using SHA-1
- * @throws CertificateEncodingException if the binary encoding of the certificate cannot be produced.
*/
static String computeThumbprint(X509Certificate cert) {
if (cert == null) {
diff --git a/src/main/java/gov/cdc/izgateway/model/IDestination.java b/src/main/java/gov/cdc/izgateway/model/IDestination.java
index 299d4c7..cee1a1b 100644
--- a/src/main/java/gov/cdc/izgateway/model/IDestination.java
+++ b/src/main/java/gov/cdc/izgateway/model/IDestination.java
@@ -36,10 +36,12 @@ static class Map extends MappableEntity {}
IDestinationId getId();
+ @JsonFormat(shape=Shape.STRING, pattern=Constants.TIMESTAMP_FORMAT)
Date getMaintEnd();
String getMaintReason();
+ @JsonFormat(shape=Shape.STRING, pattern=Constants.TIMESTAMP_FORMAT)
Date getMaintStart();
String getMsh22();
@@ -53,11 +55,13 @@ static class Map extends MappableEntity {}
String getMsh6();
String getPassword();
+
+ Date getPassExpiry();
String getRxa11();
String getUsername();
-
+
void setDestUri(String destUri);
void setDestVersion(String destVersion);
@@ -73,7 +77,6 @@ static class Map extends MappableEntity {}
void setMaintReason(String maintReason);
- @JsonFormat(shape=Shape.STRING, pattern=Constants.TIMESTAMP_FORMAT)
void setMaintStart(Date maintStart);
void setMsh22(String msh22);
@@ -88,6 +91,8 @@ static class Map extends MappableEntity {}
@JsonIgnore
void setPassword(String password);
+
+ void setPassExpiry(Date expiry);
void setRxa11(String rxa11);
diff --git a/src/main/java/gov/cdc/izgateway/security/AccessControlRegistry.java b/src/main/java/gov/cdc/izgateway/security/AccessControlRegistry.java
index d0214b2..1d9713d 100644
--- a/src/main/java/gov/cdc/izgateway/security/AccessControlRegistry.java
+++ b/src/main/java/gov/cdc/izgateway/security/AccessControlRegistry.java
@@ -87,8 +87,8 @@ private String[] getControllerPrefix(Class> controller) {
/**
* Register a controller class in this registry. Does the bulk of the work for above method.
* This method exists to enable testing without instantiating the class (possibly expensive).
- *
- * @param controller The class of the @RestController object to register
+
+ * @param controllerClass The class of the @RestController object to register
* @param prefix The prefix under which this controller is installed.
*/
public void register(Class> controllerClass, String prefix) {
@@ -150,6 +150,7 @@ private void registerPathsAndRoles(String prefix, RolesAllowed roles, MergedAnno
/**
* Dynamically add an access control for a path.
+ * @param methods The methods that apply to the path
* @param path The path to add the access control for.
* @param roles The roles to add the access control for.
*/
@@ -161,6 +162,7 @@ public void register(RequestMethod[] methods, String path, String ... roles ) {
/**
* Dynamically remove an access control for a path.
+ * @param methods The methods used with the path.
* @param path The path to add the access control for.
* @param roles The roles to add the access control for.
*/
diff --git a/src/main/java/gov/cdc/izgateway/security/CertificatePrincipal.java b/src/main/java/gov/cdc/izgateway/security/CertificatePrincipal.java
index f4a561a..3e8c921 100644
--- a/src/main/java/gov/cdc/izgateway/security/CertificatePrincipal.java
+++ b/src/main/java/gov/cdc/izgateway/security/CertificatePrincipal.java
@@ -1,10 +1,16 @@
package gov.cdc.izgateway.security;
import lombok.Data;
+import lombok.EqualsAndHashCode;
import java.math.BigInteger;
+/**
+ * A class representing a principal based on an X.509 certificate
+ * @author Audacious Inquiry
+ */
@Data
+@EqualsAndHashCode(callSuper=true)
public class CertificatePrincipal extends IzgPrincipal {
public String getSerialNumberHex() {
diff --git a/src/main/java/gov/cdc/izgateway/security/CertificateProcessor.java b/src/main/java/gov/cdc/izgateway/security/CertificateProcessor.java
new file mode 100644
index 0000000..2b02d04
--- /dev/null
+++ b/src/main/java/gov/cdc/izgateway/security/CertificateProcessor.java
@@ -0,0 +1,40 @@
+package gov.cdc.izgateway.security;
+
+import java.net.URLDecoder;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.nio.charset.StandardCharsets;
+import org.bouncycastle.util.io.pem.PemObject;
+import org.bouncycastle.util.io.pem.PemReader;
+import java.io.ByteArrayInputStream;
+import java.io.StringReader;
+import java.security.cert.*;
+
+/**
+ * CertificateProcessor provides utility methods for processing certificates.
+ */
+public class CertificateProcessor {
+ public static X509Certificate processCertificateFromHeader(String certHeader) throws CertificateException {
+ certHeader = normalizeCertHeader(certHeader);
+ return parsePemCertificate(certHeader);
+ }
+
+ private static String normalizeCertHeader(String certHeader) {
+ return URLDecoder.decode(certHeader.replace("+", "%2B"), StandardCharsets.UTF_8);
+ }
+
+ private static X509Certificate parsePemCertificate(String pemContent) throws CertificateException {
+ try (PemReader pemReader = new PemReader(new StringReader(pemContent))) {
+ PemObject pemObject = pemReader.readPemObject();
+ CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
+ return (X509Certificate) certFactory.generateCertificate(
+ new ByteArrayInputStream(pemObject.getContent()));
+ } catch (Exception e) {
+ throw new CertificateException("Failed to parse certificate", e);
+ }
+ }
+
+ // Private constructor to prevent instantiation
+ private CertificateProcessor() {
+ }
+}
diff --git a/src/main/java/gov/cdc/izgateway/security/CertificateValidator.java b/src/main/java/gov/cdc/izgateway/security/CertificateValidator.java
new file mode 100644
index 0000000..ac0efea
--- /dev/null
+++ b/src/main/java/gov/cdc/izgateway/security/CertificateValidator.java
@@ -0,0 +1,30 @@
+package gov.cdc.izgateway.security;
+
+import lombok.extern.slf4j.Slf4j;
+
+import javax.net.ssl.X509TrustManager;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+
+/**
+ * CertificateValidator provides utility methods for validating certificates.
+ */
+@Slf4j
+public class CertificateValidator {
+ private final X509TrustManager trustManager;
+
+ public CertificateValidator(X509TrustManager trustManager) {
+ this.trustManager = trustManager;
+ }
+
+ public boolean isValid(X509Certificate cert) {
+ try {
+ cert.checkValidity();
+ trustManager.checkClientTrusted(new X509Certificate[]{cert}, "TLS-client-auth");
+ return true;
+ } catch (CertificateException e) {
+ log.error("Certificate validation failed", e);
+ return false;
+ }
+ }
+}
diff --git a/src/main/java/gov/cdc/izgateway/security/ClientTlsSupport.java b/src/main/java/gov/cdc/izgateway/security/ClientTlsSupport.java
index 5ab54e5..48afe53 100644
--- a/src/main/java/gov/cdc/izgateway/security/ClientTlsSupport.java
+++ b/src/main/java/gov/cdc/izgateway/security/ClientTlsSupport.java
@@ -102,6 +102,15 @@ public void afterPropertiesSet() throws IOException {
TimeUnit.SECONDS); // specified in sections
}
+ public TrustManager[] getTrustManagers() {
+ boolean reload = checkForUpdates();
+
+ KeyStore trustStore = loadTrustStore(reload);
+ TrustManager[] tm = { getTrustManager(trustStore) };
+
+ return tm;
+ }
+
public SSLContext getSSLContext() {
boolean reload = checkForUpdates();
if (sslContext != null && !reload) {
diff --git a/src/main/java/gov/cdc/izgateway/security/TrustController.java b/src/main/java/gov/cdc/izgateway/security/TrustController.java
index a475c00..c5e8557 100644
--- a/src/main/java/gov/cdc/izgateway/security/TrustController.java
+++ b/src/main/java/gov/cdc/izgateway/security/TrustController.java
@@ -49,6 +49,7 @@ public TrustController(AccessControlRegistry registry, ClientTlsSupport tlsSuppo
this.tlsConfig = tlsSupport.getConfig();
}
+ @SuppressWarnings({ "serial" })
public class TrustDataMap extends MappableEntity {}
/**
* Report on trust parameters status.
diff --git a/src/main/java/gov/cdc/izgateway/security/filter/IpAddressFilter.java b/src/main/java/gov/cdc/izgateway/security/filter/IpAddressFilter.java
new file mode 100644
index 0000000..c9c4d71
--- /dev/null
+++ b/src/main/java/gov/cdc/izgateway/security/filter/IpAddressFilter.java
@@ -0,0 +1,88 @@
+package gov.cdc.izgateway.security.filter;
+
+import gov.cdc.izgateway.logging.RequestContext;
+import gov.cdc.izgateway.logging.markers.Markers2;
+import inet.ipaddr.IPAddress;
+import inet.ipaddr.IPAddressString;
+import jakarta.servlet.*;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.core.Ordered;
+import org.springframework.core.annotation.Order;
+import org.springframework.stereotype.Component;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+@Slf4j
+@Component
+@Order(Ordered.LOWEST_PRECEDENCE)
+public class IpAddressFilter implements Filter {
+ private List allowedSubnets = Collections.emptyList();
+ private final boolean ipFilterEnabled;
+
+ public IpAddressFilter(
+ @Value("${hub.security.ip-filter.allowed-cidr:}") String allowedCidr,
+ @Value("${hub.security.ip-filter.enabled:false}") boolean ipFilterEnabled
+ ) {
+ this.ipFilterEnabled = ipFilterEnabled;
+
+ if (this.ipFilterEnabled) {
+ if (allowedCidr == null || allowedCidr.trim().isEmpty()) {
+ throw new IllegalStateException("IP filtering enabled, no IP CIDRs configured.");
+ }
+
+ this.allowedSubnets = Arrays.stream(allowedCidr.split(","))
+ .map(String::trim)
+ .filter(cidr -> !cidr.isEmpty())
+ .map(cidr -> new IPAddressString(cidr).getAddress())
+ .collect(Collectors.toList());
+
+ log.info("IP whitelist configured with {} CIDR blocks: {}", this.allowedSubnets.size(), allowedCidr);
+
+ } else {
+ log.warn("IP filtering not enabled. All IP addresses will be allowed.");
+ }
+ }
+
+ @Override
+ public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
+ HttpServletRequest httpRequest = (HttpServletRequest) servletRequest;
+
+ // Getting the remote address, not considering X-Forwarded-For currently
+ // This gets the IP from the ALB itself whereas X-Forwarded-For would get
+ // IP address of actual caller.
+ String clientIp = httpRequest.getRemoteAddr();
+
+ if (!ipFilterEnabled) {
+ filterChain.doFilter(servletRequest, servletResponse);
+ return;
+ }
+
+ boolean allowed;
+ try {
+ IPAddress ipAddress = new IPAddressString(clientIp).getAddress();
+ allowed = allowedSubnets.stream().anyMatch(subnet -> subnet.contains(ipAddress));
+
+ } catch (Exception e) {
+ // We were unable to parse the IP address, block access
+ log.error(Markers2.append(RequestContext.getSourceInfo()), "Unable to parse/verify IP: {}. Error: {}", clientIp, e.getMessage());
+ HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;
+ httpResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
+ return;
+ }
+
+ if (allowed) {
+ filterChain.doFilter(servletRequest, servletResponse);
+ } else {
+ log.error(Markers2.append(RequestContext.getSourceInfo()), "Access denied for IP: {}. Not in any configured allowed CIDR.", clientIp);
+ HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;
+ httpResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
+ }
+ }
+}
diff --git a/src/main/java/gov/cdc/izgateway/security/filter/SecretHeaderFilter.java b/src/main/java/gov/cdc/izgateway/security/filter/SecretHeaderFilter.java
new file mode 100644
index 0000000..ff84b3a
--- /dev/null
+++ b/src/main/java/gov/cdc/izgateway/security/filter/SecretHeaderFilter.java
@@ -0,0 +1,70 @@
+package gov.cdc.izgateway.security.filter;
+
+import gov.cdc.izgateway.logging.RequestContext;
+import gov.cdc.izgateway.logging.markers.Markers2;
+import jakarta.servlet.*;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.core.Ordered;
+import org.springframework.core.annotation.Order;
+import org.springframework.stereotype.Component;
+import gov.cdc.izgateway.security.AccessControlValve;
+import java.io.IOException;
+
+/**
+ * SecretHeaderFilter is a servlet filter that checks for a specific header in incoming requests.
+ * This is used to ensure that only requests from internal services (like ALB or WAF) are allowed to pass through.
+ * It is disabled by default, but can be enabled via the settings in the constructor.
+ * If the filter is enabled but the key or value is missing, an IllegalStateException is thrown.
+ */
+@Slf4j
+@Component
+@Order(Ordered.LOWEST_PRECEDENCE)
+public class SecretHeaderFilter implements Filter {
+ private final boolean headerFilterEnabled;
+ private final String headerFilterKey;
+ private final String headerFilterValue;
+
+ public SecretHeaderFilter(
+ @Value("${hub.security.secret-header-filter.enabled:false}") boolean headerFilterEnabled,
+ @Value("${hub.security.secret-header-filter.key:}") String headerFilterKey,
+ @Value("${hub.security.secret-header-filter.value:}") String headerFilterValue
+ ) {
+ this.headerFilterEnabled = headerFilterEnabled;
+ this.headerFilterKey = headerFilterKey;
+ this.headerFilterValue = headerFilterValue;
+
+ if (this.headerFilterEnabled) {
+ if (StringUtils.isEmpty(this.headerFilterKey) || StringUtils.isEmpty(this.headerFilterValue)) {
+ throw new IllegalStateException("Secret header filter is enabled, but the header key or value is not set.");
+ }
+ } else {
+ log.warn("Secret header filter not enabled. Requests will not be filtered.");
+ }
+ }
+
+ @Override
+ public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
+ if (!headerFilterEnabled || AccessControlValve.isLocalHost(servletRequest.getRemoteAddr())) {
+ filterChain.doFilter(servletRequest, servletResponse);
+ return;
+ }
+
+ // Check if the request is from an internal service (ALB, WAF) by checking the header
+ HttpServletRequest request = (HttpServletRequest) servletRequest;
+ HttpServletResponse response = (HttpServletResponse) servletResponse;
+ String headerValue = request.getHeader(headerFilterKey);
+
+ if (headerValue == null || !headerValue.equals(headerFilterValue)) {
+ log.error(Markers2.append(RequestContext.getSourceInfo()), "Request does not contain the secret header, rejecting request.");
+ response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
+ return;
+ }
+
+ // Continue the filter chain
+ filterChain.doFilter(servletRequest, servletResponse);
+ }
+}
diff --git a/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java b/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java
index 7c05460..10c6ab3 100644
--- a/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java
+++ b/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java
@@ -1,49 +1,79 @@
package gov.cdc.izgateway.security.principal;
import gov.cdc.izgateway.principal.provider.CertificatePrincipalProvider;
-import gov.cdc.izgateway.security.CertificatePrincipal;
-import gov.cdc.izgateway.security.IzgPrincipal;
+import gov.cdc.izgateway.security.*;
import gov.cdc.izgateway.utils.X500Utils;
import jakarta.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.apache.catalina.Globals;
import org.apache.commons.lang3.StringUtils;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
+import javax.net.ssl.X509TrustManager;
import javax.security.auth.x500.X500Principal;
-import java.security.cert.X509Certificate;
+import java.security.cert.*;
import java.util.Map;
/**
- * A Principal Provider that uses certificates to identify the principal
- * by common name and organization.
- *
- * @author Audacious Inquiry
- *
+ * This class provides a principal from a certificate
*/
@Slf4j
@Component
public class CertificatePrincipalProviderImpl implements CertificatePrincipalProvider {
+ @Value("${client.ssl.certificate-header:}")
+ private String certHeaderKey;
+
+ private final CertificateValidator validator;
+
+ @Autowired
+ public CertificatePrincipalProviderImpl(ClientTlsSupport clientTlsSupport) {
+ this.validator = new CertificateValidator((X509TrustManager) clientTlsSupport.getTrustManagers()[0]);
+ }
@Override
public IzgPrincipal createPrincipalFromCertificate(HttpServletRequest request) {
- X509Certificate[] certs = (X509Certificate[]) request.getAttribute(Globals.CERTIFICATES_ATTR);
+ X509Certificate cert = getCertificate(request);
+ return cert != null ? createPrincipalFromCertificate(cert) : null;
+ }
+
+ /**
+ * Gets the certificate from two possible places. If a certificate is present in the request attribute,
+ * it is returned. If not, the certificate is extracted from the header. Null is returned if no certificate
+ * is found or if the certificate is invalid.
+ * @param request The request
+ * @return The certificate
+ */
+ private X509Certificate getCertificate(HttpServletRequest request) {
+ X509Certificate cert = getCertificateFromAttribute(request);
+ if (cert != null) return cert;
+
+ String certHeader = request.getHeader(certHeaderKey);
+ if (StringUtils.isBlank(certHeader)) return null;
- if (certs == null || certs.length == 0) {
+ try {
+ cert = CertificateProcessor.processCertificateFromHeader(certHeader);
+ return validator.isValid(cert) ? cert : null;
+ } catch (CertificateException e) {
+ log.error("Failed to process certificate from header", e);
return null;
}
+ }
- return createPrincipalFromCertificate(certs[0]);
+ private X509Certificate getCertificateFromAttribute(HttpServletRequest request) {
+ X509Certificate[] certs = (X509Certificate[]) request.getAttribute(Globals.CERTIFICATES_ATTR);
+ return (certs != null && certs.length > 0) ? certs[0] : null;
}
- /**
- * Create a principal from a certificate
- * @param cert The certificate
- * @return The principal
- */
- public static IzgPrincipal createPrincipalFromCertificate(X509Certificate cert) {
+ /**
+ * Create a principal from a certificate
+ * @param cert The certificate
+ * @return The principal
+ */
+ public static IzgPrincipal createPrincipalFromCertificate(X509Certificate cert) {
IzgPrincipal principal = new CertificatePrincipal();
- X500Principal subject = cert.getSubjectX500Principal();
+ X500Principal subject = cert.getSubjectX500Principal();
Map parts = X500Utils.getParts(subject);
principal.setName(parts.get(X500Utils.COMMON_NAME));
@@ -52,12 +82,13 @@ public static IzgPrincipal createPrincipalFromCertificate(X509Certificate cert)
o = parts.get(X500Utils.ORGANIZATION_UNIT);
}
principal.setOrganization(o);
-
principal.setValidFrom(cert.getNotBefore());
principal.setValidTo(cert.getNotAfter());
principal.setSerialNumber(String.valueOf(cert.getSerialNumber()));
principal.setIssuer(cert.getIssuerX500Principal().getName());
return principal;
- }
+ }
+
}
+
diff --git a/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java b/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java
index f2f4db9..edc6464 100644
--- a/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java
+++ b/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java
@@ -40,5 +40,4 @@ protected boolean validConfiguration() {
log.warn("No JWT shared secret was set. JWT authentication is disabled.");
return false;
}
-
}
diff --git a/src/main/java/gov/cdc/izgateway/soap/MockMessage.java b/src/main/java/gov/cdc/izgateway/soap/MockMessage.java
index d1ca3c2..d1c6a20 100644
--- a/src/main/java/gov/cdc/izgateway/soap/MockMessage.java
+++ b/src/main/java/gov/cdc/izgateway/soap/MockMessage.java
@@ -328,11 +328,10 @@ private static String repeatString(MockMessage response, int size) {
/**
* Get the message response this mock is intended to return.
*
- * @param resp
- * The HttpServeltResponse (may be used to return a non-XML body)
* @param message
* The SOAP message object, used again to return a non-XML body.
* @return The mocked response
+ * @throws Fault To mock a generic fault
* @throws SecurityFault
* To mock a security fault.
* @throws MessageTooLargeFault
diff --git a/src/main/java/gov/cdc/izgateway/soap/fault/FaultSupport.java b/src/main/java/gov/cdc/izgateway/soap/fault/FaultSupport.java
index f1b4bfe..d4ffd47 100644
--- a/src/main/java/gov/cdc/izgateway/soap/fault/FaultSupport.java
+++ b/src/main/java/gov/cdc/izgateway/soap/fault/FaultSupport.java
@@ -5,8 +5,6 @@
/**
* Access to this interface is supported by Faults created in this package to enable
* structured diagnostics and logging with Faults.
- *
- * @see HasFaultSupport
*/
public interface FaultSupport {
/** Name for Summary content in fault */
@@ -52,6 +50,7 @@ public interface FaultSupport {
* for automated interpretation that embodies the type of fault, the subtype and the retry strategy
* to apply. It is a 3 digit code in which the first digit indicate the fault type, the second
* provides the subtype, and the third, the retry strategy.
+ * @return The code for error
*/
String getCode();
diff --git a/src/main/java/gov/cdc/izgateway/soap/message/SoapMessage.java b/src/main/java/gov/cdc/izgateway/soap/message/SoapMessage.java
index 17da14a..40b399f 100644
--- a/src/main/java/gov/cdc/izgateway/soap/message/SoapMessage.java
+++ b/src/main/java/gov/cdc/izgateway/soap/message/SoapMessage.java
@@ -66,6 +66,7 @@ public static interface Response {}
*
* @param that The message to copy from.
* @param schema The schema to use.
+ * @param isUpgradeOrSchemaChange true if this is an upgrade or schema change
*/
public SoapMessage(SoapMessage that, String schema, boolean isUpgradeOrSchemaChange) {
// Always copy the HubHeader
diff --git a/src/main/java/gov/cdc/izgateway/soap/net/MessageSender.java b/src/main/java/gov/cdc/izgateway/soap/net/MessageSender.java
index 991ffae..589e328 100644
--- a/src/main/java/gov/cdc/izgateway/soap/net/MessageSender.java
+++ b/src/main/java/gov/cdc/izgateway/soap/net/MessageSender.java
@@ -398,18 +398,21 @@ private T readResult(Class clazz, IDestination dest,
// Mark the buffer so we can reread on error.
m = new HttpUrlConnectionInputMessage(con, clientConfig.getMaxBufferSize());
statusCode = m.getStatusCode();
+ logDestinationCertificates(con);
body = m.getBody();
m.mark();
EndPointInfo endPoint = RequestContext.getDestinationInfo();
if (ACCEPTABLE_RESPONSE_CODES.contains(statusCode)) {
result = converter.read(m, endPoint);
- if (result instanceof FaultMessage fm) {
+ if (result instanceof FaultMessage) {
m.reset();
throw HubClientFault.clientThrewFault(null, dest, statusCode, body, result);
}
return clazz.cast(result);
} else {
- throw processHttpError(dest, statusCode, con.getErrorStream());
+ try (InputStream errStream = con.getErrorStream()) {
+ throw processHttpError(dest, statusCode, errStream);
+ }
}
} catch (ClassCastException ex) {
savedEx = ex;
@@ -507,10 +510,10 @@ public static void logDestinationCertificates(HttpURLConnection con) {
DestinationInfo destination = RequestContext.getDestinationInfo();
if (destination.isConnected() && con instanceof HttpsURLConnection conx) {
try {
- X509Certificate[] certs = (X509Certificate[]) conx.getServerCertificates();
- destination.setCertificate(certs[0]);
destination.setCipherSuite(conx.getCipherSuite());
destination.setConnected(true);
+ X509Certificate[] certs = (X509Certificate[]) conx.getServerCertificates();
+ destination.setCertificate(certs[0]);
} catch (SSLPeerUnverifiedException | IllegalStateException ex) {
// Ignore this.
}
diff --git a/src/main/java/gov/cdc/izgateway/soap/net/SoapMessageReader.java b/src/main/java/gov/cdc/izgateway/soap/net/SoapMessageReader.java
index f5d7928..955f4ac 100644
--- a/src/main/java/gov/cdc/izgateway/soap/net/SoapMessageReader.java
+++ b/src/main/java/gov/cdc/izgateway/soap/net/SoapMessageReader.java
@@ -32,7 +32,6 @@
import lombok.extern.slf4j.Slf4j;
import javax.xml.namespace.QName;
-import javax.xml.stream.Location;
import javax.xml.stream.XMLStreamConstants;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;
diff --git a/src/main/java/gov/cdc/izgateway/utils/HL7Utils.java b/src/main/java/gov/cdc/izgateway/utils/HL7Utils.java
index d764f12..fe9a305 100644
--- a/src/main/java/gov/cdc/izgateway/utils/HL7Utils.java
+++ b/src/main/java/gov/cdc/izgateway/utils/HL7Utils.java
@@ -5,15 +5,10 @@
import java.util.TreeMap;
import java.util.List;
import java.io.IOException;
-import java.io.InputStream;
-import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.UnaryOperator;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
-
import org.apache.commons.lang3.StringUtils;
import lombok.Getter;
@@ -106,10 +101,9 @@ private static Collection adjustNonMSHAllowedValues(Collection
/**
* Strip an HL7 Segment of allowed fields
*
- * @param b The string builder to copy the result to
* @param segment The segment to strip.
* @param allowedFields The set of allowed fields.
- * @returns The stripped segment
+ * @return The stripped segment
*/
public static String stripSegment(String segment, Collection allowedFields) {
return stripParts(segment, allowedFields, "|", HL7Utils::stripCWE);
@@ -232,7 +226,7 @@ private enum ParseState {
WITHIN_SEGMENT,
END_SEGMENT_NAME,
CHOMP_TO_DELIMITER
- };
+ }
/**
* Mash PHI in segments
* This function ensures that HL7 message content potentially containing PHI is removed from the message.
diff --git a/src/main/java/gov/cdc/izgateway/utils/ListConverter.java b/src/main/java/gov/cdc/izgateway/utils/ListConverter.java
index 1d97cd8..d948217 100644
--- a/src/main/java/gov/cdc/izgateway/utils/ListConverter.java
+++ b/src/main/java/gov/cdc/izgateway/utils/ListConverter.java
@@ -10,6 +10,8 @@
* allow one to iterate over "To" items without creating a bunch of new objects. This is especially helpful
* to translate between interface classes. It is used in Logging to convert from the LogBack IListEvent class
* to the LogEvent class in IZ Gateway mostly to serve as a way to document logging events in Swagger.
+ * @param The class to convert from
+ * @param The class to convert to
*/
public final class ListConverter extends AbstractList {
private final List events;
diff --git a/src/main/java/gov/cdc/izgateway/utils/XmlUtils.java b/src/main/java/gov/cdc/izgateway/utils/XmlUtils.java
index 2153dfa..901caf6 100644
--- a/src/main/java/gov/cdc/izgateway/utils/XmlUtils.java
+++ b/src/main/java/gov/cdc/izgateway/utils/XmlUtils.java
@@ -26,7 +26,6 @@
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.xml.sax.InputSource;
-import org.xml.sax.SAXException;
public class XmlUtils {
private static DocumentBuilder documentBuilder = getDocumentBuilder();
diff --git a/src/test/java/gov/cdc/izgateway/security/filter/IpAddressFilterTests.java b/src/test/java/gov/cdc/izgateway/security/filter/IpAddressFilterTests.java
new file mode 100644
index 0000000..4de8788
--- /dev/null
+++ b/src/test/java/gov/cdc/izgateway/security/filter/IpAddressFilterTests.java
@@ -0,0 +1,156 @@
+package gov.cdc.izgateway.security.filter;
+
+import jakarta.servlet.FilterChain;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import java.io.IOException;
+
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.*;
+
+@ExtendWith(MockitoExtension.class)
+public class IpAddressFilterTests {
+
+ @Mock
+ private HttpServletRequest request;
+
+ @Mock
+ private HttpServletResponse response;
+
+ @Mock
+ private FilterChain filterChain;
+
+ @BeforeEach
+ void setUp() {
+ reset(request, response, filterChain);
+ }
+
+ @Test
+ void testConstructorWithValidCidr() {
+ IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24,10.0.0.0/8", true);
+ assertNotNull(filter);
+ }
+
+ @Test
+ void testConstructorWithEmptyCidrWhenEnabled() {
+ assertThrows(IllegalStateException.class, () -> new IpAddressFilter("", true));
+ }
+
+ @Test
+ void testConstructorWithNullCidrWhenEnabled() {
+ assertThrows(IllegalStateException.class, () -> new IpAddressFilter(null, true));
+ }
+
+ @Test
+ void testConstructorWithEmptyCidrWhenDisabled() {
+ IpAddressFilter filter = new IpAddressFilter("", false);
+ assertNotNull(filter);
+ }
+
+ @Test
+ void testAllowedIpWhenFilteringEnabled() throws IOException, ServletException {
+ IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24", true);
+ when(request.getRemoteAddr()).thenReturn("192.168.1.100");
+
+ filter.doFilter(request, response, filterChain);
+
+ verify(filterChain).doFilter(request, response);
+ verify(response, never()).setStatus(anyInt());
+ }
+
+ @Test
+ void testDisallowedIpWhenFilteringEnabled() throws IOException, ServletException {
+ IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24", true);
+ when(request.getRemoteAddr()).thenReturn("10.0.0.1");
+
+ filter.doFilter(request, response, filterChain);
+
+ verify(filterChain, never()).doFilter(any(), any());
+ verify(response).setStatus(HttpServletResponse.SC_FORBIDDEN);
+ }
+
+ @Test
+ void testAnyIpAllowedWhenFilteringDisabled() throws IOException, ServletException {
+ IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24", false);
+ // IP outside the allowed range, but should not matter as filtering is disabled
+ when(request.getRemoteAddr()).thenReturn("10.0.0.1");
+
+ filter.doFilter(request, response, filterChain);
+
+ verify(filterChain).doFilter(request, response);
+ verify(response, never()).setStatus(anyInt());
+ }
+
+ @Test
+ void testMultipleCidrRanges() throws IOException, ServletException {
+ IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24,10.0.0.0/8", true);
+
+ // Test first range
+ when(request.getRemoteAddr()).thenReturn("192.168.1.100");
+ filter.doFilter(request, response, filterChain);
+ verify(filterChain).doFilter(request, response);
+ reset(request, response, filterChain);
+
+ // Test second range
+ when(request.getRemoteAddr()).thenReturn("10.10.10.10");
+ filter.doFilter(request, response, filterChain);
+ verify(filterChain).doFilter(request, response);
+ reset(request, response, filterChain);
+
+ // Test outside range
+ when(request.getRemoteAddr()).thenReturn("172.16.0.1");
+ filter.doFilter(request, response, filterChain);
+ verify(filterChain, never()).doFilter(any(), any());
+ verify(response).setStatus(HttpServletResponse.SC_FORBIDDEN);
+ }
+
+ @Test
+ void testIpv6Localhost() throws IOException, ServletException {
+ IpAddressFilter filter = new IpAddressFilter("127.0.0.1/32,::1/128", true);
+
+ when(request.getRemoteAddr()).thenReturn("0:0:0:0:0:0:0:1");
+ filter.doFilter(request, response, filterChain);
+ verify(filterChain).doFilter(request, response);
+ verify(response, never()).setStatus(anyInt());
+
+ reset(request, response, filterChain);
+ }
+
+ @Test
+ void testNonLocalIpv6Address() throws IOException, ServletException {
+ // Setup with IPv4 CIDR only
+ IpAddressFilter filter = new IpAddressFilter("::1/128", true);
+
+ // Test with a non-localhost IPv6 address
+ when(request.getRemoteAddr()).thenReturn("2001:0db8:85a3:0000:0000:8a2e:0370:7334");
+ filter.doFilter(request, response, filterChain);
+
+ // Should be denied as the IPv6 address isn't in allowed CIDRs
+ verify(filterChain, never()).doFilter(any(), any());
+ verify(response).setStatus(HttpServletResponse.SC_FORBIDDEN);
+ }
+
+ @Test
+ void testAllowedIpv6Address() throws IOException, ServletException {
+ // Setup with both IPv4 and IPv6 CIDRs
+ IpAddressFilter filter = new IpAddressFilter("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128", true);
+
+ // Test with an allowed IPv6 address
+ when(request.getRemoteAddr()).thenReturn("2001:0db8:85a3:0000:0000:8a2e:0370:7334");
+ filter.doFilter(request, response, filterChain);
+
+ // Should be allowed as the IPv6 address is in allowed CIDR
+ verify(filterChain).doFilter(request, response);
+ verify(response, never()).setStatus(anyInt());
+ }
+}
diff --git a/src/test/java/gov/cdc/izgateway/security/filter/SecretHeaderFilterTests.java b/src/test/java/gov/cdc/izgateway/security/filter/SecretHeaderFilterTests.java
new file mode 100644
index 0000000..a7b7fff
--- /dev/null
+++ b/src/test/java/gov/cdc/izgateway/security/filter/SecretHeaderFilterTests.java
@@ -0,0 +1,86 @@
+// File: src/test/java/gov/cdc/izgateway/security/filter/SecretHeaderFilterTests.java
+
+package gov.cdc.izgateway.security.filter;
+
+import jakarta.servlet.FilterChain;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import java.io.IOException;
+
+import static org.mockito.Mockito.*;
+
+@ExtendWith(MockitoExtension.class)
+public class SecretHeaderFilterTests {
+
+ @Mock
+ private HttpServletRequest request;
+
+ @Mock
+ private HttpServletResponse response;
+
+ @Mock
+ private FilterChain filterChain;
+
+ private SecretHeaderFilter filter;
+
+ @BeforeEach
+ void setUp() {
+ // Initialize the filter with test values
+ filter = new SecretHeaderFilter(true, "x-alb-secret", "secret-value");
+ }
+
+ @Test
+ void testFilterDisabled() throws IOException, ServletException {
+ // Initialize the filter with disabled state
+ filter = new SecretHeaderFilter(false, "", "");
+
+ filter.doFilter(request, response, filterChain);
+
+ verify(filterChain).doFilter(request, response);
+ verify(response, never()).setStatus(anyInt());
+ }
+
+ @Test
+ void testFilterEnabledWithoutKeyOrValue() {
+ // Expect IllegalStateException when filter is enabled but key or value is missing
+ assertThrows(IllegalStateException.class, () -> new SecretHeaderFilter(true, "", ""));
+ }
+
+ @Test
+ void testRequestWithoutHeader() throws IOException, ServletException {
+ when(request.getHeader("x-alb-secret")).thenReturn(null);
+
+ filter.doFilter(request, response, filterChain);
+
+ verify(filterChain, never()).doFilter(any(), any());
+ verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
+ }
+
+ @Test
+ void testRequestWithInvalidHeader() throws IOException, ServletException {
+ when(request.getHeader("x-alb-secret")).thenReturn("invalid-value");
+
+ filter.doFilter(request, response, filterChain);
+
+ verify(filterChain, never()).doFilter(any(), any());
+ verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
+ }
+
+ @Test
+ void testRequestWithValidHeader() throws IOException, ServletException {
+ when(request.getHeader("x-alb-secret")).thenReturn("secret-value");
+
+ filter.doFilter(request, response, filterChain);
+
+ verify(filterChain).doFilter(request, response);
+ verify(response, never()).setStatus(anyInt());
+ }
+}
diff --git a/src/test/java/test/gov/cdc/izgateway/utils/TestHL7Utils.java b/src/test/java/test/gov/cdc/izgateway/utils/TestHL7Utils.java
index 50a464e..515ee68 100644
--- a/src/test/java/test/gov/cdc/izgateway/utils/TestHL7Utils.java
+++ b/src/test/java/test/gov/cdc/izgateway/utils/TestHL7Utils.java
@@ -2,8 +2,6 @@
import static org.junit.jupiter.api.Assertions.*;
-import org.apache.commons.text.StringEscapeUtils;
-import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;