diff --git a/pom.xml b/pom.xml index ec79c41..ce248f1 100644 --- a/pom.xml +++ b/pom.xml @@ -240,6 +240,11 @@ org.springframework.boot spring-boot-starter-oauth2-resource-server + + com.github.seancfoley + ipaddress + 5.5.1 + diff --git a/src/main/java/gov/cdc/izgateway/logging/LoggingValve.java b/src/main/java/gov/cdc/izgateway/logging/LoggingValve.java index 4a459b3..8547038 100644 --- a/src/main/java/gov/cdc/izgateway/logging/LoggingValve.java +++ b/src/main/java/gov/cdc/izgateway/logging/LoggingValve.java @@ -44,37 +44,11 @@ public class LoggingValve extends LoggingValveBase implements EventCreator { Collections.unmodifiableList(Arrays.asList(EVENT_ID, SESSION_ID, METHOD, IP_ADDRESS, REQUEST_URI, COMMON_NAME)); private static final String REST_ADS = "/rest/ads"; - // Keep mappings for at most one minute. - private static final int MAX_AGE = 60 * 1000; - @SuppressWarnings("unused") + @SuppressWarnings("unused") private ScheduledFuture adsMonitor = Executors.newSingleThreadScheduledExecutor(r -> new Thread(r, "ADS Monitor")) .scheduleAtFixedRate(this::monitorADSRequests, 0, 15, TimeUnit.SECONDS); private static final ConcurrentHashMap adsRequests = new ConcurrentHashMap<>(); - private Map map = new LinkedHashMap<>(); - - private static class LoggingValveEvent implements Event { - private final String id; - private Date date; - private int refs; - - private LoggingValveEvent(String id, Date date) { - refs = 1; - this.id = id; - this.date = date; - } - - @Override - public String getId() { - return id; - } - - @Override - public Date getDate() { - return date; - } - } - @Autowired public LoggingValve(PrincipalService principalService) { this.principalService = principalService; diff --git a/src/main/java/gov/cdc/izgateway/logging/event/TransactionData.java b/src/main/java/gov/cdc/izgateway/logging/event/TransactionData.java index cc001ee..a0cdcea 100644 --- a/src/main/java/gov/cdc/izgateway/logging/event/TransactionData.java +++ b/src/main/java/gov/cdc/izgateway/logging/event/TransactionData.java @@ -500,8 +500,6 @@ public void setResponseEchoBack(String val) { setResponsePayloadSize(StringUtils.length(val)); } - private static final Pattern TEST_MESSAGE_PATTERN - = Pattern.compile("^(([A-Z]+(AIRA|TEST)\\^[A-Z]+(AIRA|TEST))|([A-Z]*IZG[A-Z]*))\\^", Pattern.CASE_INSENSITIVE); /** * Indicates if the message matches known test patterns * @param message The message diff --git a/src/main/java/gov/cdc/izgateway/logging/info/EndPointInfo.java b/src/main/java/gov/cdc/izgateway/logging/info/EndPointInfo.java index ac2dfd4..8adb385 100644 --- a/src/main/java/gov/cdc/izgateway/logging/info/EndPointInfo.java +++ b/src/main/java/gov/cdc/izgateway/logging/info/EndPointInfo.java @@ -1,13 +1,7 @@ package gov.cdc.izgateway.logging.info; import java.io.Serializable; -import java.security.cert.X509Certificate; import java.util.Date; -import java.util.Map; - - -import gov.cdc.izgateway.security.IzgPrincipal; -import gov.cdc.izgateway.utils.X500Utils; import org.apache.commons.lang3.StringUtils; import com.fasterxml.jackson.annotation.JsonFormat; import com.fasterxml.jackson.annotation.JsonFormat.Shape; @@ -20,8 +14,6 @@ import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; -import javax.security.auth.x500.X500Principal; - /** * An endpoint describes the inbound or outbound connection to * IZ Gateway during a transaction. It is abstract to ensure diff --git a/src/main/java/gov/cdc/izgateway/logging/info/SourceInfo.java b/src/main/java/gov/cdc/izgateway/logging/info/SourceInfo.java index 02aacd3..2c2ddd8 100644 --- a/src/main/java/gov/cdc/izgateway/logging/info/SourceInfo.java +++ b/src/main/java/gov/cdc/izgateway/logging/info/SourceInfo.java @@ -7,7 +7,6 @@ import gov.cdc.izgateway.security.IzgPrincipal; import gov.cdc.izgateway.security.principal.CertificatePrincipalProviderImpl; -import gov.cdc.izgateway.utils.X500Utils; import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; import lombok.EqualsAndHashCode; @@ -68,7 +67,7 @@ public void setPrincipal(IzgPrincipal principal) { * @param certificate */ public void setCertificate(X509Certificate certificate) { - setPrincipal(CertificatePrincipalProviderImpl.createPrincipalFromCertificate(certificate)); + setPrincipal(CertificatePrincipalProviderImpl.createPrincipalFromCertificate(certificate)); } } diff --git a/src/main/java/gov/cdc/izgateway/model/ICertificateStatus.java b/src/main/java/gov/cdc/izgateway/model/ICertificateStatus.java index 20a2eee..18e6894 100644 --- a/src/main/java/gov/cdc/izgateway/model/ICertificateStatus.java +++ b/src/main/java/gov/cdc/izgateway/model/ICertificateStatus.java @@ -4,7 +4,6 @@ import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateEncodingException; import java.security.cert.X509Certificate; -import java.sql.Timestamp; import java.util.Date; import java.util.ServiceConfigurationError; @@ -51,7 +50,6 @@ static MessageDigest getMessageDigest() { * * @param cert The certificate to compute the thumbprint. * @return A string representing the thumbprint using SHA-1 - * @throws CertificateEncodingException if the binary encoding of the certificate cannot be produced. */ static String computeThumbprint(X509Certificate cert) { if (cert == null) { diff --git a/src/main/java/gov/cdc/izgateway/model/IDestination.java b/src/main/java/gov/cdc/izgateway/model/IDestination.java index 299d4c7..cee1a1b 100644 --- a/src/main/java/gov/cdc/izgateway/model/IDestination.java +++ b/src/main/java/gov/cdc/izgateway/model/IDestination.java @@ -36,10 +36,12 @@ static class Map extends MappableEntity {} IDestinationId getId(); + @JsonFormat(shape=Shape.STRING, pattern=Constants.TIMESTAMP_FORMAT) Date getMaintEnd(); String getMaintReason(); + @JsonFormat(shape=Shape.STRING, pattern=Constants.TIMESTAMP_FORMAT) Date getMaintStart(); String getMsh22(); @@ -53,11 +55,13 @@ static class Map extends MappableEntity {} String getMsh6(); String getPassword(); + + Date getPassExpiry(); String getRxa11(); String getUsername(); - + void setDestUri(String destUri); void setDestVersion(String destVersion); @@ -73,7 +77,6 @@ static class Map extends MappableEntity {} void setMaintReason(String maintReason); - @JsonFormat(shape=Shape.STRING, pattern=Constants.TIMESTAMP_FORMAT) void setMaintStart(Date maintStart); void setMsh22(String msh22); @@ -88,6 +91,8 @@ static class Map extends MappableEntity {} @JsonIgnore void setPassword(String password); + + void setPassExpiry(Date expiry); void setRxa11(String rxa11); diff --git a/src/main/java/gov/cdc/izgateway/security/AccessControlRegistry.java b/src/main/java/gov/cdc/izgateway/security/AccessControlRegistry.java index d0214b2..1d9713d 100644 --- a/src/main/java/gov/cdc/izgateway/security/AccessControlRegistry.java +++ b/src/main/java/gov/cdc/izgateway/security/AccessControlRegistry.java @@ -87,8 +87,8 @@ private String[] getControllerPrefix(Class controller) { /** * Register a controller class in this registry. Does the bulk of the work for above method. * This method exists to enable testing without instantiating the class (possibly expensive). - * - * @param controller The class of the @RestController object to register + + * @param controllerClass The class of the @RestController object to register * @param prefix The prefix under which this controller is installed. */ public void register(Class controllerClass, String prefix) { @@ -150,6 +150,7 @@ private void registerPathsAndRoles(String prefix, RolesAllowed roles, MergedAnno /** * Dynamically add an access control for a path. + * @param methods The methods that apply to the path * @param path The path to add the access control for. * @param roles The roles to add the access control for. */ @@ -161,6 +162,7 @@ public void register(RequestMethod[] methods, String path, String ... roles ) { /** * Dynamically remove an access control for a path. + * @param methods The methods used with the path. * @param path The path to add the access control for. * @param roles The roles to add the access control for. */ diff --git a/src/main/java/gov/cdc/izgateway/security/CertificatePrincipal.java b/src/main/java/gov/cdc/izgateway/security/CertificatePrincipal.java index f4a561a..3e8c921 100644 --- a/src/main/java/gov/cdc/izgateway/security/CertificatePrincipal.java +++ b/src/main/java/gov/cdc/izgateway/security/CertificatePrincipal.java @@ -1,10 +1,16 @@ package gov.cdc.izgateway.security; import lombok.Data; +import lombok.EqualsAndHashCode; import java.math.BigInteger; +/** + * A class representing a principal based on an X.509 certificate + * @author Audacious Inquiry + */ @Data +@EqualsAndHashCode(callSuper=true) public class CertificatePrincipal extends IzgPrincipal { public String getSerialNumberHex() { diff --git a/src/main/java/gov/cdc/izgateway/security/CertificateProcessor.java b/src/main/java/gov/cdc/izgateway/security/CertificateProcessor.java new file mode 100644 index 0000000..2b02d04 --- /dev/null +++ b/src/main/java/gov/cdc/izgateway/security/CertificateProcessor.java @@ -0,0 +1,40 @@ +package gov.cdc.izgateway.security; + +import java.net.URLDecoder; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.nio.charset.StandardCharsets; +import org.bouncycastle.util.io.pem.PemObject; +import org.bouncycastle.util.io.pem.PemReader; +import java.io.ByteArrayInputStream; +import java.io.StringReader; +import java.security.cert.*; + +/** + * CertificateProcessor provides utility methods for processing certificates. + */ +public class CertificateProcessor { + public static X509Certificate processCertificateFromHeader(String certHeader) throws CertificateException { + certHeader = normalizeCertHeader(certHeader); + return parsePemCertificate(certHeader); + } + + private static String normalizeCertHeader(String certHeader) { + return URLDecoder.decode(certHeader.replace("+", "%2B"), StandardCharsets.UTF_8); + } + + private static X509Certificate parsePemCertificate(String pemContent) throws CertificateException { + try (PemReader pemReader = new PemReader(new StringReader(pemContent))) { + PemObject pemObject = pemReader.readPemObject(); + CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); + return (X509Certificate) certFactory.generateCertificate( + new ByteArrayInputStream(pemObject.getContent())); + } catch (Exception e) { + throw new CertificateException("Failed to parse certificate", e); + } + } + + // Private constructor to prevent instantiation + private CertificateProcessor() { + } +} diff --git a/src/main/java/gov/cdc/izgateway/security/CertificateValidator.java b/src/main/java/gov/cdc/izgateway/security/CertificateValidator.java new file mode 100644 index 0000000..ac0efea --- /dev/null +++ b/src/main/java/gov/cdc/izgateway/security/CertificateValidator.java @@ -0,0 +1,30 @@ +package gov.cdc.izgateway.security; + +import lombok.extern.slf4j.Slf4j; + +import javax.net.ssl.X509TrustManager; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +/** + * CertificateValidator provides utility methods for validating certificates. + */ +@Slf4j +public class CertificateValidator { + private final X509TrustManager trustManager; + + public CertificateValidator(X509TrustManager trustManager) { + this.trustManager = trustManager; + } + + public boolean isValid(X509Certificate cert) { + try { + cert.checkValidity(); + trustManager.checkClientTrusted(new X509Certificate[]{cert}, "TLS-client-auth"); + return true; + } catch (CertificateException e) { + log.error("Certificate validation failed", e); + return false; + } + } +} diff --git a/src/main/java/gov/cdc/izgateway/security/ClientTlsSupport.java b/src/main/java/gov/cdc/izgateway/security/ClientTlsSupport.java index 5ab54e5..48afe53 100644 --- a/src/main/java/gov/cdc/izgateway/security/ClientTlsSupport.java +++ b/src/main/java/gov/cdc/izgateway/security/ClientTlsSupport.java @@ -102,6 +102,15 @@ public void afterPropertiesSet() throws IOException { TimeUnit.SECONDS); // specified in sections } + public TrustManager[] getTrustManagers() { + boolean reload = checkForUpdates(); + + KeyStore trustStore = loadTrustStore(reload); + TrustManager[] tm = { getTrustManager(trustStore) }; + + return tm; + } + public SSLContext getSSLContext() { boolean reload = checkForUpdates(); if (sslContext != null && !reload) { diff --git a/src/main/java/gov/cdc/izgateway/security/TrustController.java b/src/main/java/gov/cdc/izgateway/security/TrustController.java index a475c00..c5e8557 100644 --- a/src/main/java/gov/cdc/izgateway/security/TrustController.java +++ b/src/main/java/gov/cdc/izgateway/security/TrustController.java @@ -49,6 +49,7 @@ public TrustController(AccessControlRegistry registry, ClientTlsSupport tlsSuppo this.tlsConfig = tlsSupport.getConfig(); } + @SuppressWarnings({ "serial" }) public class TrustDataMap extends MappableEntity {} /** * Report on trust parameters status. diff --git a/src/main/java/gov/cdc/izgateway/security/filter/IpAddressFilter.java b/src/main/java/gov/cdc/izgateway/security/filter/IpAddressFilter.java new file mode 100644 index 0000000..c9c4d71 --- /dev/null +++ b/src/main/java/gov/cdc/izgateway/security/filter/IpAddressFilter.java @@ -0,0 +1,88 @@ +package gov.cdc.izgateway.security.filter; + +import gov.cdc.izgateway.logging.RequestContext; +import gov.cdc.izgateway.logging.markers.Markers2; +import inet.ipaddr.IPAddress; +import inet.ipaddr.IPAddressString; +import jakarta.servlet.*; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +@Slf4j +@Component +@Order(Ordered.LOWEST_PRECEDENCE) +public class IpAddressFilter implements Filter { + private List allowedSubnets = Collections.emptyList(); + private final boolean ipFilterEnabled; + + public IpAddressFilter( + @Value("${hub.security.ip-filter.allowed-cidr:}") String allowedCidr, + @Value("${hub.security.ip-filter.enabled:false}") boolean ipFilterEnabled + ) { + this.ipFilterEnabled = ipFilterEnabled; + + if (this.ipFilterEnabled) { + if (allowedCidr == null || allowedCidr.trim().isEmpty()) { + throw new IllegalStateException("IP filtering enabled, no IP CIDRs configured."); + } + + this.allowedSubnets = Arrays.stream(allowedCidr.split(",")) + .map(String::trim) + .filter(cidr -> !cidr.isEmpty()) + .map(cidr -> new IPAddressString(cidr).getAddress()) + .collect(Collectors.toList()); + + log.info("IP whitelist configured with {} CIDR blocks: {}", this.allowedSubnets.size(), allowedCidr); + + } else { + log.warn("IP filtering not enabled. All IP addresses will be allowed."); + } + } + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { + HttpServletRequest httpRequest = (HttpServletRequest) servletRequest; + + // Getting the remote address, not considering X-Forwarded-For currently + // This gets the IP from the ALB itself whereas X-Forwarded-For would get + // IP address of actual caller. + String clientIp = httpRequest.getRemoteAddr(); + + if (!ipFilterEnabled) { + filterChain.doFilter(servletRequest, servletResponse); + return; + } + + boolean allowed; + try { + IPAddress ipAddress = new IPAddressString(clientIp).getAddress(); + allowed = allowedSubnets.stream().anyMatch(subnet -> subnet.contains(ipAddress)); + + } catch (Exception e) { + // We were unable to parse the IP address, block access + log.error(Markers2.append(RequestContext.getSourceInfo()), "Unable to parse/verify IP: {}. Error: {}", clientIp, e.getMessage()); + HttpServletResponse httpResponse = (HttpServletResponse) servletResponse; + httpResponse.setStatus(HttpServletResponse.SC_FORBIDDEN); + return; + } + + if (allowed) { + filterChain.doFilter(servletRequest, servletResponse); + } else { + log.error(Markers2.append(RequestContext.getSourceInfo()), "Access denied for IP: {}. Not in any configured allowed CIDR.", clientIp); + HttpServletResponse httpResponse = (HttpServletResponse) servletResponse; + httpResponse.setStatus(HttpServletResponse.SC_FORBIDDEN); + } + } +} diff --git a/src/main/java/gov/cdc/izgateway/security/filter/SecretHeaderFilter.java b/src/main/java/gov/cdc/izgateway/security/filter/SecretHeaderFilter.java new file mode 100644 index 0000000..ff84b3a --- /dev/null +++ b/src/main/java/gov/cdc/izgateway/security/filter/SecretHeaderFilter.java @@ -0,0 +1,70 @@ +package gov.cdc.izgateway.security.filter; + +import gov.cdc.izgateway.logging.RequestContext; +import gov.cdc.izgateway.logging.markers.Markers2; +import jakarta.servlet.*; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; +import gov.cdc.izgateway.security.AccessControlValve; +import java.io.IOException; + +/** + * SecretHeaderFilter is a servlet filter that checks for a specific header in incoming requests. + * This is used to ensure that only requests from internal services (like ALB or WAF) are allowed to pass through. + * It is disabled by default, but can be enabled via the settings in the constructor. + * If the filter is enabled but the key or value is missing, an IllegalStateException is thrown. + */ +@Slf4j +@Component +@Order(Ordered.LOWEST_PRECEDENCE) +public class SecretHeaderFilter implements Filter { + private final boolean headerFilterEnabled; + private final String headerFilterKey; + private final String headerFilterValue; + + public SecretHeaderFilter( + @Value("${hub.security.secret-header-filter.enabled:false}") boolean headerFilterEnabled, + @Value("${hub.security.secret-header-filter.key:}") String headerFilterKey, + @Value("${hub.security.secret-header-filter.value:}") String headerFilterValue + ) { + this.headerFilterEnabled = headerFilterEnabled; + this.headerFilterKey = headerFilterKey; + this.headerFilterValue = headerFilterValue; + + if (this.headerFilterEnabled) { + if (StringUtils.isEmpty(this.headerFilterKey) || StringUtils.isEmpty(this.headerFilterValue)) { + throw new IllegalStateException("Secret header filter is enabled, but the header key or value is not set."); + } + } else { + log.warn("Secret header filter not enabled. Requests will not be filtered."); + } + } + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { + if (!headerFilterEnabled || AccessControlValve.isLocalHost(servletRequest.getRemoteAddr())) { + filterChain.doFilter(servletRequest, servletResponse); + return; + } + + // Check if the request is from an internal service (ALB, WAF) by checking the header + HttpServletRequest request = (HttpServletRequest) servletRequest; + HttpServletResponse response = (HttpServletResponse) servletResponse; + String headerValue = request.getHeader(headerFilterKey); + + if (headerValue == null || !headerValue.equals(headerFilterValue)) { + log.error(Markers2.append(RequestContext.getSourceInfo()), "Request does not contain the secret header, rejecting request."); + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); + return; + } + + // Continue the filter chain + filterChain.doFilter(servletRequest, servletResponse); + } +} diff --git a/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java b/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java index 7c05460..10c6ab3 100644 --- a/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java +++ b/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java @@ -1,49 +1,79 @@ package gov.cdc.izgateway.security.principal; import gov.cdc.izgateway.principal.provider.CertificatePrincipalProvider; -import gov.cdc.izgateway.security.CertificatePrincipal; -import gov.cdc.izgateway.security.IzgPrincipal; +import gov.cdc.izgateway.security.*; import gov.cdc.izgateway.utils.X500Utils; import jakarta.servlet.http.HttpServletRequest; import lombok.extern.slf4j.Slf4j; import org.apache.catalina.Globals; import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; +import javax.net.ssl.X509TrustManager; import javax.security.auth.x500.X500Principal; -import java.security.cert.X509Certificate; +import java.security.cert.*; import java.util.Map; /** - * A Principal Provider that uses certificates to identify the principal - * by common name and organization. - * - * @author Audacious Inquiry - * + * This class provides a principal from a certificate */ @Slf4j @Component public class CertificatePrincipalProviderImpl implements CertificatePrincipalProvider { + @Value("${client.ssl.certificate-header:}") + private String certHeaderKey; + + private final CertificateValidator validator; + + @Autowired + public CertificatePrincipalProviderImpl(ClientTlsSupport clientTlsSupport) { + this.validator = new CertificateValidator((X509TrustManager) clientTlsSupport.getTrustManagers()[0]); + } @Override public IzgPrincipal createPrincipalFromCertificate(HttpServletRequest request) { - X509Certificate[] certs = (X509Certificate[]) request.getAttribute(Globals.CERTIFICATES_ATTR); + X509Certificate cert = getCertificate(request); + return cert != null ? createPrincipalFromCertificate(cert) : null; + } + + /** + * Gets the certificate from two possible places. If a certificate is present in the request attribute, + * it is returned. If not, the certificate is extracted from the header. Null is returned if no certificate + * is found or if the certificate is invalid. + * @param request The request + * @return The certificate + */ + private X509Certificate getCertificate(HttpServletRequest request) { + X509Certificate cert = getCertificateFromAttribute(request); + if (cert != null) return cert; + + String certHeader = request.getHeader(certHeaderKey); + if (StringUtils.isBlank(certHeader)) return null; - if (certs == null || certs.length == 0) { + try { + cert = CertificateProcessor.processCertificateFromHeader(certHeader); + return validator.isValid(cert) ? cert : null; + } catch (CertificateException e) { + log.error("Failed to process certificate from header", e); return null; } + } - return createPrincipalFromCertificate(certs[0]); + private X509Certificate getCertificateFromAttribute(HttpServletRequest request) { + X509Certificate[] certs = (X509Certificate[]) request.getAttribute(Globals.CERTIFICATES_ATTR); + return (certs != null && certs.length > 0) ? certs[0] : null; } - /** - * Create a principal from a certificate - * @param cert The certificate - * @return The principal - */ - public static IzgPrincipal createPrincipalFromCertificate(X509Certificate cert) { + /** + * Create a principal from a certificate + * @param cert The certificate + * @return The principal + */ + public static IzgPrincipal createPrincipalFromCertificate(X509Certificate cert) { IzgPrincipal principal = new CertificatePrincipal(); - X500Principal subject = cert.getSubjectX500Principal(); + X500Principal subject = cert.getSubjectX500Principal(); Map parts = X500Utils.getParts(subject); principal.setName(parts.get(X500Utils.COMMON_NAME)); @@ -52,12 +82,13 @@ public static IzgPrincipal createPrincipalFromCertificate(X509Certificate cert) o = parts.get(X500Utils.ORGANIZATION_UNIT); } principal.setOrganization(o); - principal.setValidFrom(cert.getNotBefore()); principal.setValidTo(cert.getNotAfter()); principal.setSerialNumber(String.valueOf(cert.getSerialNumber())); principal.setIssuer(cert.getIssuerX500Principal().getName()); return principal; - } + } + } + diff --git a/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java b/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java index f2f4db9..edc6464 100644 --- a/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java +++ b/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java @@ -40,5 +40,4 @@ protected boolean validConfiguration() { log.warn("No JWT shared secret was set. JWT authentication is disabled."); return false; } - } diff --git a/src/main/java/gov/cdc/izgateway/soap/MockMessage.java b/src/main/java/gov/cdc/izgateway/soap/MockMessage.java index d1ca3c2..d1c6a20 100644 --- a/src/main/java/gov/cdc/izgateway/soap/MockMessage.java +++ b/src/main/java/gov/cdc/izgateway/soap/MockMessage.java @@ -328,11 +328,10 @@ private static String repeatString(MockMessage response, int size) { /** * Get the message response this mock is intended to return. * - * @param resp - * The HttpServeltResponse (may be used to return a non-XML body) * @param message * The SOAP message object, used again to return a non-XML body. * @return The mocked response + * @throws Fault To mock a generic fault * @throws SecurityFault * To mock a security fault. * @throws MessageTooLargeFault diff --git a/src/main/java/gov/cdc/izgateway/soap/fault/FaultSupport.java b/src/main/java/gov/cdc/izgateway/soap/fault/FaultSupport.java index f1b4bfe..d4ffd47 100644 --- a/src/main/java/gov/cdc/izgateway/soap/fault/FaultSupport.java +++ b/src/main/java/gov/cdc/izgateway/soap/fault/FaultSupport.java @@ -5,8 +5,6 @@ /** * Access to this interface is supported by Faults created in this package to enable * structured diagnostics and logging with Faults. - * - * @see HasFaultSupport */ public interface FaultSupport { /** Name for Summary content in fault */ @@ -52,6 +50,7 @@ public interface FaultSupport { * for automated interpretation that embodies the type of fault, the subtype and the retry strategy * to apply. It is a 3 digit code in which the first digit indicate the fault type, the second * provides the subtype, and the third, the retry strategy. + * @return The code for error */ String getCode(); diff --git a/src/main/java/gov/cdc/izgateway/soap/message/SoapMessage.java b/src/main/java/gov/cdc/izgateway/soap/message/SoapMessage.java index 17da14a..40b399f 100644 --- a/src/main/java/gov/cdc/izgateway/soap/message/SoapMessage.java +++ b/src/main/java/gov/cdc/izgateway/soap/message/SoapMessage.java @@ -66,6 +66,7 @@ public static interface Response {} * * @param that The message to copy from. * @param schema The schema to use. + * @param isUpgradeOrSchemaChange true if this is an upgrade or schema change */ public SoapMessage(SoapMessage that, String schema, boolean isUpgradeOrSchemaChange) { // Always copy the HubHeader diff --git a/src/main/java/gov/cdc/izgateway/soap/net/MessageSender.java b/src/main/java/gov/cdc/izgateway/soap/net/MessageSender.java index 991ffae..589e328 100644 --- a/src/main/java/gov/cdc/izgateway/soap/net/MessageSender.java +++ b/src/main/java/gov/cdc/izgateway/soap/net/MessageSender.java @@ -398,18 +398,21 @@ private T readResult(Class clazz, IDestination dest, // Mark the buffer so we can reread on error. m = new HttpUrlConnectionInputMessage(con, clientConfig.getMaxBufferSize()); statusCode = m.getStatusCode(); + logDestinationCertificates(con); body = m.getBody(); m.mark(); EndPointInfo endPoint = RequestContext.getDestinationInfo(); if (ACCEPTABLE_RESPONSE_CODES.contains(statusCode)) { result = converter.read(m, endPoint); - if (result instanceof FaultMessage fm) { + if (result instanceof FaultMessage) { m.reset(); throw HubClientFault.clientThrewFault(null, dest, statusCode, body, result); } return clazz.cast(result); } else { - throw processHttpError(dest, statusCode, con.getErrorStream()); + try (InputStream errStream = con.getErrorStream()) { + throw processHttpError(dest, statusCode, errStream); + } } } catch (ClassCastException ex) { savedEx = ex; @@ -507,10 +510,10 @@ public static void logDestinationCertificates(HttpURLConnection con) { DestinationInfo destination = RequestContext.getDestinationInfo(); if (destination.isConnected() && con instanceof HttpsURLConnection conx) { try { - X509Certificate[] certs = (X509Certificate[]) conx.getServerCertificates(); - destination.setCertificate(certs[0]); destination.setCipherSuite(conx.getCipherSuite()); destination.setConnected(true); + X509Certificate[] certs = (X509Certificate[]) conx.getServerCertificates(); + destination.setCertificate(certs[0]); } catch (SSLPeerUnverifiedException | IllegalStateException ex) { // Ignore this. } diff --git a/src/main/java/gov/cdc/izgateway/soap/net/SoapMessageReader.java b/src/main/java/gov/cdc/izgateway/soap/net/SoapMessageReader.java index f5d7928..955f4ac 100644 --- a/src/main/java/gov/cdc/izgateway/soap/net/SoapMessageReader.java +++ b/src/main/java/gov/cdc/izgateway/soap/net/SoapMessageReader.java @@ -32,7 +32,6 @@ import lombok.extern.slf4j.Slf4j; import javax.xml.namespace.QName; -import javax.xml.stream.Location; import javax.xml.stream.XMLStreamConstants; import javax.xml.stream.XMLStreamException; import javax.xml.stream.XMLStreamReader; diff --git a/src/main/java/gov/cdc/izgateway/utils/HL7Utils.java b/src/main/java/gov/cdc/izgateway/utils/HL7Utils.java index d764f12..fe9a305 100644 --- a/src/main/java/gov/cdc/izgateway/utils/HL7Utils.java +++ b/src/main/java/gov/cdc/izgateway/utils/HL7Utils.java @@ -5,15 +5,10 @@ import java.util.TreeMap; import java.util.List; import java.io.IOException; -import java.io.InputStream; -import java.io.Reader; import java.io.StringReader; import java.util.ArrayList; import java.util.Arrays; import java.util.function.UnaryOperator; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - import org.apache.commons.lang3.StringUtils; import lombok.Getter; @@ -106,10 +101,9 @@ private static Collection adjustNonMSHAllowedValues(Collection /** * Strip an HL7 Segment of allowed fields * - * @param b The string builder to copy the result to * @param segment The segment to strip. * @param allowedFields The set of allowed fields. - * @returns The stripped segment + * @return The stripped segment */ public static String stripSegment(String segment, Collection allowedFields) { return stripParts(segment, allowedFields, "|", HL7Utils::stripCWE); @@ -232,7 +226,7 @@ private enum ParseState { WITHIN_SEGMENT, END_SEGMENT_NAME, CHOMP_TO_DELIMITER - }; + } /** * Mash PHI in segments * This function ensures that HL7 message content potentially containing PHI is removed from the message. diff --git a/src/main/java/gov/cdc/izgateway/utils/ListConverter.java b/src/main/java/gov/cdc/izgateway/utils/ListConverter.java index 1d97cd8..d948217 100644 --- a/src/main/java/gov/cdc/izgateway/utils/ListConverter.java +++ b/src/main/java/gov/cdc/izgateway/utils/ListConverter.java @@ -10,6 +10,8 @@ * allow one to iterate over "To" items without creating a bunch of new objects. This is especially helpful * to translate between interface classes. It is used in Logging to convert from the LogBack IListEvent class * to the LogEvent class in IZ Gateway mostly to serve as a way to document logging events in Swagger. + * @param The class to convert from + * @param The class to convert to */ public final class ListConverter extends AbstractList { private final List events; diff --git a/src/main/java/gov/cdc/izgateway/utils/XmlUtils.java b/src/main/java/gov/cdc/izgateway/utils/XmlUtils.java index 2153dfa..901caf6 100644 --- a/src/main/java/gov/cdc/izgateway/utils/XmlUtils.java +++ b/src/main/java/gov/cdc/izgateway/utils/XmlUtils.java @@ -26,7 +26,6 @@ import org.w3c.dom.Document; import org.w3c.dom.Node; import org.xml.sax.InputSource; -import org.xml.sax.SAXException; public class XmlUtils { private static DocumentBuilder documentBuilder = getDocumentBuilder(); diff --git a/src/test/java/gov/cdc/izgateway/security/filter/IpAddressFilterTests.java b/src/test/java/gov/cdc/izgateway/security/filter/IpAddressFilterTests.java new file mode 100644 index 0000000..4de8788 --- /dev/null +++ b/src/test/java/gov/cdc/izgateway/security/filter/IpAddressFilterTests.java @@ -0,0 +1,156 @@ +package gov.cdc.izgateway.security.filter; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +public class IpAddressFilterTests { + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private FilterChain filterChain; + + @BeforeEach + void setUp() { + reset(request, response, filterChain); + } + + @Test + void testConstructorWithValidCidr() { + IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24,10.0.0.0/8", true); + assertNotNull(filter); + } + + @Test + void testConstructorWithEmptyCidrWhenEnabled() { + assertThrows(IllegalStateException.class, () -> new IpAddressFilter("", true)); + } + + @Test + void testConstructorWithNullCidrWhenEnabled() { + assertThrows(IllegalStateException.class, () -> new IpAddressFilter(null, true)); + } + + @Test + void testConstructorWithEmptyCidrWhenDisabled() { + IpAddressFilter filter = new IpAddressFilter("", false); + assertNotNull(filter); + } + + @Test + void testAllowedIpWhenFilteringEnabled() throws IOException, ServletException { + IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24", true); + when(request.getRemoteAddr()).thenReturn("192.168.1.100"); + + filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(request, response); + verify(response, never()).setStatus(anyInt()); + } + + @Test + void testDisallowedIpWhenFilteringEnabled() throws IOException, ServletException { + IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24", true); + when(request.getRemoteAddr()).thenReturn("10.0.0.1"); + + filter.doFilter(request, response, filterChain); + + verify(filterChain, never()).doFilter(any(), any()); + verify(response).setStatus(HttpServletResponse.SC_FORBIDDEN); + } + + @Test + void testAnyIpAllowedWhenFilteringDisabled() throws IOException, ServletException { + IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24", false); + // IP outside the allowed range, but should not matter as filtering is disabled + when(request.getRemoteAddr()).thenReturn("10.0.0.1"); + + filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(request, response); + verify(response, never()).setStatus(anyInt()); + } + + @Test + void testMultipleCidrRanges() throws IOException, ServletException { + IpAddressFilter filter = new IpAddressFilter("192.168.1.0/24,10.0.0.0/8", true); + + // Test first range + when(request.getRemoteAddr()).thenReturn("192.168.1.100"); + filter.doFilter(request, response, filterChain); + verify(filterChain).doFilter(request, response); + reset(request, response, filterChain); + + // Test second range + when(request.getRemoteAddr()).thenReturn("10.10.10.10"); + filter.doFilter(request, response, filterChain); + verify(filterChain).doFilter(request, response); + reset(request, response, filterChain); + + // Test outside range + when(request.getRemoteAddr()).thenReturn("172.16.0.1"); + filter.doFilter(request, response, filterChain); + verify(filterChain, never()).doFilter(any(), any()); + verify(response).setStatus(HttpServletResponse.SC_FORBIDDEN); + } + + @Test + void testIpv6Localhost() throws IOException, ServletException { + IpAddressFilter filter = new IpAddressFilter("127.0.0.1/32,::1/128", true); + + when(request.getRemoteAddr()).thenReturn("0:0:0:0:0:0:0:1"); + filter.doFilter(request, response, filterChain); + verify(filterChain).doFilter(request, response); + verify(response, never()).setStatus(anyInt()); + + reset(request, response, filterChain); + } + + @Test + void testNonLocalIpv6Address() throws IOException, ServletException { + // Setup with IPv4 CIDR only + IpAddressFilter filter = new IpAddressFilter("::1/128", true); + + // Test with a non-localhost IPv6 address + when(request.getRemoteAddr()).thenReturn("2001:0db8:85a3:0000:0000:8a2e:0370:7334"); + filter.doFilter(request, response, filterChain); + + // Should be denied as the IPv6 address isn't in allowed CIDRs + verify(filterChain, never()).doFilter(any(), any()); + verify(response).setStatus(HttpServletResponse.SC_FORBIDDEN); + } + + @Test + void testAllowedIpv6Address() throws IOException, ServletException { + // Setup with both IPv4 and IPv6 CIDRs + IpAddressFilter filter = new IpAddressFilter("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128", true); + + // Test with an allowed IPv6 address + when(request.getRemoteAddr()).thenReturn("2001:0db8:85a3:0000:0000:8a2e:0370:7334"); + filter.doFilter(request, response, filterChain); + + // Should be allowed as the IPv6 address is in allowed CIDR + verify(filterChain).doFilter(request, response); + verify(response, never()).setStatus(anyInt()); + } +} diff --git a/src/test/java/gov/cdc/izgateway/security/filter/SecretHeaderFilterTests.java b/src/test/java/gov/cdc/izgateway/security/filter/SecretHeaderFilterTests.java new file mode 100644 index 0000000..a7b7fff --- /dev/null +++ b/src/test/java/gov/cdc/izgateway/security/filter/SecretHeaderFilterTests.java @@ -0,0 +1,86 @@ +// File: src/test/java/gov/cdc/izgateway/security/filter/SecretHeaderFilterTests.java + +package gov.cdc.izgateway.security.filter; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.io.IOException; + +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +public class SecretHeaderFilterTests { + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private FilterChain filterChain; + + private SecretHeaderFilter filter; + + @BeforeEach + void setUp() { + // Initialize the filter with test values + filter = new SecretHeaderFilter(true, "x-alb-secret", "secret-value"); + } + + @Test + void testFilterDisabled() throws IOException, ServletException { + // Initialize the filter with disabled state + filter = new SecretHeaderFilter(false, "", ""); + + filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(request, response); + verify(response, never()).setStatus(anyInt()); + } + + @Test + void testFilterEnabledWithoutKeyOrValue() { + // Expect IllegalStateException when filter is enabled but key or value is missing + assertThrows(IllegalStateException.class, () -> new SecretHeaderFilter(true, "", "")); + } + + @Test + void testRequestWithoutHeader() throws IOException, ServletException { + when(request.getHeader("x-alb-secret")).thenReturn(null); + + filter.doFilter(request, response, filterChain); + + verify(filterChain, never()).doFilter(any(), any()); + verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + void testRequestWithInvalidHeader() throws IOException, ServletException { + when(request.getHeader("x-alb-secret")).thenReturn("invalid-value"); + + filter.doFilter(request, response, filterChain); + + verify(filterChain, never()).doFilter(any(), any()); + verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + void testRequestWithValidHeader() throws IOException, ServletException { + when(request.getHeader("x-alb-secret")).thenReturn("secret-value"); + + filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(request, response); + verify(response, never()).setStatus(anyInt()); + } +} diff --git a/src/test/java/test/gov/cdc/izgateway/utils/TestHL7Utils.java b/src/test/java/test/gov/cdc/izgateway/utils/TestHL7Utils.java index 50a464e..515ee68 100644 --- a/src/test/java/test/gov/cdc/izgateway/utils/TestHL7Utils.java +++ b/src/test/java/test/gov/cdc/izgateway/utils/TestHL7Utils.java @@ -2,8 +2,6 @@ import static org.junit.jupiter.api.Assertions.*; -import org.apache.commons.text.StringEscapeUtils; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource;