Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR to bring latest changes from develop into Release_v2.2.0-branch #27

Merged
merged 5 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
</dependency>
<dependency>
<groupId>com.github.seancfoley</groupId>
<artifactId>ipaddress</artifactId>
<version>5.5.1</version>
</dependency>
</dependencies>
<build>
<pluginManagement>
Expand Down
28 changes: 1 addition & 27 deletions src/main/java/gov/cdc/izgateway/logging/LoggingValve.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,37 +44,11 @@ public class LoggingValve extends LoggingValveBase implements EventCreator {
Collections.unmodifiableList(Arrays.asList(EVENT_ID, SESSION_ID, METHOD, IP_ADDRESS, REQUEST_URI, COMMON_NAME));

private static final String REST_ADS = "/rest/ads";
// Keep mappings for at most one minute.
private static final int MAX_AGE = 60 * 1000;
@SuppressWarnings("unused")
@SuppressWarnings("unused")
private ScheduledFuture<?> adsMonitor =
Executors.newSingleThreadScheduledExecutor(r -> new Thread(r, "ADS Monitor"))
.scheduleAtFixedRate(this::monitorADSRequests, 0, 15, TimeUnit.SECONDS);
private static final ConcurrentHashMap<Request, String> adsRequests = new ConcurrentHashMap<>();
private Map<String, LoggingValveEvent> map = new LinkedHashMap<>();

private static class LoggingValveEvent implements Event {
private final String id;
private Date date;
private int refs;

private LoggingValveEvent(String id, Date date) {
refs = 1;
this.id = id;
this.date = date;
}

@Override
public String getId() {
return id;
}

@Override
public Date getDate() {
return date;
}
}

@Autowired
public LoggingValve(PrincipalService principalService) {
this.principalService = principalService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,6 @@ public void setResponseEchoBack(String val) {
setResponsePayloadSize(StringUtils.length(val));
}

private static final Pattern TEST_MESSAGE_PATTERN
= Pattern.compile("^(([A-Z]+(AIRA|TEST)\\^[A-Z]+(AIRA|TEST))|([A-Z]*IZG[A-Z]*))\\^", Pattern.CASE_INSENSITIVE);
/**
* Indicates if the message matches known test patterns
* @param message The message
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
package gov.cdc.izgateway.logging.info;

import java.io.Serializable;
import java.security.cert.X509Certificate;
import java.util.Date;
import java.util.Map;


import gov.cdc.izgateway.security.IzgPrincipal;
import gov.cdc.izgateway.utils.X500Utils;
import org.apache.commons.lang3.StringUtils;
import com.fasterxml.jackson.annotation.JsonFormat;
import com.fasterxml.jackson.annotation.JsonFormat.Shape;
Expand All @@ -20,8 +14,6 @@
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;

import javax.security.auth.x500.X500Principal;

/**
* An endpoint describes the inbound or outbound connection to
* IZ Gateway during a transaction. It is abstract to ensure
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/gov/cdc/izgateway/logging/info/SourceInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import gov.cdc.izgateway.security.IzgPrincipal;
import gov.cdc.izgateway.security.principal.CertificatePrincipalProviderImpl;
import gov.cdc.izgateway.utils.X500Utils;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
Expand Down Expand Up @@ -68,7 +67,7 @@ public void setPrincipal(IzgPrincipal principal) {
* @param certificate
*/
public void setCertificate(X509Certificate certificate) {
setPrincipal(CertificatePrincipalProviderImpl.createPrincipalFromCertificate(certificate));
setPrincipal(CertificatePrincipalProviderImpl.createPrincipalFromCertificate(certificate));
}

}
2 changes: 0 additions & 2 deletions src/main/java/gov/cdc/izgateway/model/ICertificateStatus.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.sql.Timestamp;
import java.util.Date;
import java.util.ServiceConfigurationError;

Expand Down Expand Up @@ -51,7 +50,6 @@ static MessageDigest getMessageDigest() {
*
* @param cert The certificate to compute the thumbprint.
* @return A string representing the thumbprint using SHA-1
* @throws CertificateEncodingException if the binary encoding of the certificate cannot be produced.
*/
static String computeThumbprint(X509Certificate cert) {
if (cert == null) {
Expand Down
9 changes: 7 additions & 2 deletions src/main/java/gov/cdc/izgateway/model/IDestination.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ static class Map extends MappableEntity<IDestination> {}

IDestinationId getId();

@JsonFormat(shape=Shape.STRING, pattern=Constants.TIMESTAMP_FORMAT)
Date getMaintEnd();

String getMaintReason();

@JsonFormat(shape=Shape.STRING, pattern=Constants.TIMESTAMP_FORMAT)
Date getMaintStart();

String getMsh22();
Expand All @@ -53,11 +55,13 @@ static class Map extends MappableEntity<IDestination> {}
String getMsh6();

String getPassword();

Date getPassExpiry();

String getRxa11();

String getUsername();

void setDestUri(String destUri);

void setDestVersion(String destVersion);
Expand All @@ -73,7 +77,6 @@ static class Map extends MappableEntity<IDestination> {}

void setMaintReason(String maintReason);

@JsonFormat(shape=Shape.STRING, pattern=Constants.TIMESTAMP_FORMAT)
void setMaintStart(Date maintStart);

void setMsh22(String msh22);
Expand All @@ -88,6 +91,8 @@ static class Map extends MappableEntity<IDestination> {}

@JsonIgnore
void setPassword(String password);

void setPassExpiry(Date expiry);

void setRxa11(String rxa11);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ private String[] getControllerPrefix(Class<?> controller) {
/**
* Register a controller class in this registry. Does the bulk of the work for above method.
* This method exists to enable testing without instantiating the class (possibly expensive).
*
* @param controller The class of the @RestController object to register

* @param controllerClass The class of the @RestController object to register
* @param prefix The prefix under which this controller is installed.
*/
public void register(Class<?> controllerClass, String prefix) {
Expand Down Expand Up @@ -150,6 +150,7 @@ private void registerPathsAndRoles(String prefix, RolesAllowed roles, MergedAnno

/**
* Dynamically add an access control for a path.
* @param methods The methods that apply to the path
* @param path The path to add the access control for.
* @param roles The roles to add the access control for.
*/
Expand All @@ -161,6 +162,7 @@ public void register(RequestMethod[] methods, String path, String ... roles ) {

/**
* Dynamically remove an access control for a path.
* @param methods The methods used with the path.
* @param path The path to add the access control for.
* @param roles The roles to add the access control for.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package gov.cdc.izgateway.security;

import lombok.Data;
import lombok.EqualsAndHashCode;

import java.math.BigInteger;

/**
* A class representing a principal based on an X.509 certificate
* @author Audacious Inquiry
*/
@Data
@EqualsAndHashCode(callSuper=true)
public class CertificatePrincipal extends IzgPrincipal {

public String getSerialNumberHex() {
Expand Down
40 changes: 40 additions & 0 deletions src/main/java/gov/cdc/izgateway/security/CertificateProcessor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package gov.cdc.izgateway.security;

import java.net.URLDecoder;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.nio.charset.StandardCharsets;
import org.bouncycastle.util.io.pem.PemObject;
import org.bouncycastle.util.io.pem.PemReader;
import java.io.ByteArrayInputStream;
import java.io.StringReader;
import java.security.cert.*;

/**
* CertificateProcessor provides utility methods for processing certificates.
*/
public class CertificateProcessor {
public static X509Certificate processCertificateFromHeader(String certHeader) throws CertificateException {
certHeader = normalizeCertHeader(certHeader);
return parsePemCertificate(certHeader);
}

private static String normalizeCertHeader(String certHeader) {
return URLDecoder.decode(certHeader.replace("+", "%2B"), StandardCharsets.UTF_8);
}

private static X509Certificate parsePemCertificate(String pemContent) throws CertificateException {
try (PemReader pemReader = new PemReader(new StringReader(pemContent))) {
PemObject pemObject = pemReader.readPemObject();
CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
return (X509Certificate) certFactory.generateCertificate(
new ByteArrayInputStream(pemObject.getContent()));
} catch (Exception e) {
throw new CertificateException("Failed to parse certificate", e);
}
}

// Private constructor to prevent instantiation
private CertificateProcessor() {
}
}
30 changes: 30 additions & 0 deletions src/main/java/gov/cdc/izgateway/security/CertificateValidator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package gov.cdc.izgateway.security;

import lombok.extern.slf4j.Slf4j;

import javax.net.ssl.X509TrustManager;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;

/**
* CertificateValidator provides utility methods for validating certificates.
*/
@Slf4j
public class CertificateValidator {
private final X509TrustManager trustManager;

public CertificateValidator(X509TrustManager trustManager) {
this.trustManager = trustManager;
}

public boolean isValid(X509Certificate cert) {
try {
cert.checkValidity();
trustManager.checkClientTrusted(new X509Certificate[]{cert}, "TLS-client-auth");
return true;
} catch (CertificateException e) {
log.error("Certificate validation failed", e);
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ public void afterPropertiesSet() throws IOException {
TimeUnit.SECONDS); // specified in sections
}

public TrustManager[] getTrustManagers() {
boolean reload = checkForUpdates();

KeyStore trustStore = loadTrustStore(reload);
TrustManager[] tm = { getTrustManager(trustStore) };

return tm;
}

public SSLContext getSSLContext() {
boolean reload = checkForUpdates();
if (sslContext != null && !reload) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public TrustController(AccessControlRegistry registry, ClientTlsSupport tlsSuppo
this.tlsConfig = tlsSupport.getConfig();
}

@SuppressWarnings({ "serial" })
public class TrustDataMap extends MappableEntity<TrustData> {}
/**
* Report on trust parameters status.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package gov.cdc.izgateway.security.filter;

import gov.cdc.izgateway.logging.RequestContext;
import gov.cdc.izgateway.logging.markers.Markers2;
import inet.ipaddr.IPAddress;
import inet.ipaddr.IPAddressString;
import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

@Slf4j
@Component
@Order(Ordered.LOWEST_PRECEDENCE)
public class IpAddressFilter implements Filter {
private List<IPAddress> allowedSubnets = Collections.emptyList();
private final boolean ipFilterEnabled;

public IpAddressFilter(
@Value("${hub.security.ip-filter.allowed-cidr:}") String allowedCidr,
@Value("${hub.security.ip-filter.enabled:false}") boolean ipFilterEnabled
) {
this.ipFilterEnabled = ipFilterEnabled;

if (this.ipFilterEnabled) {
if (allowedCidr == null || allowedCidr.trim().isEmpty()) {
throw new IllegalStateException("IP filtering enabled, no IP CIDRs configured.");
}

this.allowedSubnets = Arrays.stream(allowedCidr.split(","))
.map(String::trim)
.filter(cidr -> !cidr.isEmpty())
.map(cidr -> new IPAddressString(cidr).getAddress())
.collect(Collectors.toList());

log.info("IP whitelist configured with {} CIDR blocks: {}", this.allowedSubnets.size(), allowedCidr);

} else {
log.warn("IP filtering not enabled. All IP addresses will be allowed.");
}
}

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest httpRequest = (HttpServletRequest) servletRequest;

// Getting the remote address, not considering X-Forwarded-For currently
// This gets the IP from the ALB itself whereas X-Forwarded-For would get
// IP address of actual caller.
String clientIp = httpRequest.getRemoteAddr();

if (!ipFilterEnabled) {
filterChain.doFilter(servletRequest, servletResponse);
return;
}

boolean allowed;
try {
IPAddress ipAddress = new IPAddressString(clientIp).getAddress();
allowed = allowedSubnets.stream().anyMatch(subnet -> subnet.contains(ipAddress));

} catch (Exception e) {
// We were unable to parse the IP address, block access
log.error(Markers2.append(RequestContext.getSourceInfo()), "Unable to parse/verify IP: {}. Error: {}", clientIp, e.getMessage());
HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;
httpResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
return;
}

if (allowed) {
filterChain.doFilter(servletRequest, servletResponse);
} else {
log.error(Markers2.append(RequestContext.getSourceInfo()), "Access denied for IP: {}. Not in any configured allowed CIDR.", clientIp);
HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;
httpResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
}
}
}
Loading