Skip to content

Improve internal communication authentication #26000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,54 +14,45 @@
package io.trino.server;

import com.google.common.collect.ImmutableList;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hashing;
import com.google.inject.Inject;
import io.airlift.http.client.HttpRequestFilter;
import io.airlift.http.client.Request;
import io.airlift.log.Logger;
import io.airlift.node.NodeInfo;
import io.jsonwebtoken.JwtException;
import io.jsonwebtoken.JwtParser;
import io.trino.server.security.InternalPrincipal;
import io.trino.server.security.SecurityConfig;
import io.trino.spi.security.Identity;
import jakarta.ws.rs.ForbiddenException;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Response;

import javax.crypto.SecretKey;

import java.time.Instant;
import java.time.ZonedDateTime;
import java.util.Date;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.net.URI;
import java.util.Base64;

import static io.airlift.http.client.Request.Builder.fromRequest;
import static io.jsonwebtoken.security.Keys.hmacShaKeyFor;
import static io.trino.server.ServletSecurityUtils.setAuthenticatedIdentity;
import static io.trino.server.security.jwt.JwtUtil.newJwtBuilder;
import static io.trino.server.security.jwt.JwtUtil.newJwtParserBuilder;
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
import static jakarta.ws.rs.core.Response.Status.UNAUTHORIZED;
import static java.lang.Long.parseLong;
import static java.lang.System.currentTimeMillis;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.time.temporal.ChronoUnit.MINUTES;
import static java.util.Objects.requireNonNull;

public class InternalAuthenticationManager
implements HttpRequestFilter
{
private static final Logger log = Logger.get(InternalAuthenticationManager.class);
private static final Supplier<Instant> DEFAULT_EXPIRATION_SUPPLIER = () -> ZonedDateTime.now().plusMinutes(6).toInstant();
// Leave a 5 minute buffer to allow for clock skew and GC pauses
private static final Function<Instant, Instant> TOKEN_REUSE_THRESHOLD = instant -> instant.minus(5, MINUTES);

private static final String TRINO_INTERNAL_BEARER = "X-Trino-Internal-Bearer";
private static final int MAX_REQUEST_AGE_SECONDS = 300; // 5 minutes

private static final String TRINO_INTERNAL_SIGNATURE = "X-Trino-Internal-Signature";
private static final String TRINO_INTERNAL_NODE_ID = "X-Trino-Internal-Node-Id";
private static final String TRINO_INTERNAL_REQUEST_TIMESTAMP = "X-Trino-Internal-Timestamp";

private final SecretKey hmac;
private final HashFunction hashing;
private final String nodeId;
private final JwtParser jwtParser;
private final AtomicReference<InternalToken> currentToken;

@Inject
public InternalAuthenticationManager(InternalCommunicationConfig internalCommunicationConfig, SecurityConfig securityConfig, NodeInfo nodeInfo)
Expand Down Expand Up @@ -90,93 +81,72 @@ public InternalAuthenticationManager(String sharedSecret, String nodeId)
{
requireNonNull(sharedSecret, "sharedSecret is null");
requireNonNull(nodeId, "nodeId is null");
this.hmac = hmacShaKeyFor(Hashing.sha256().hashString(sharedSecret, UTF_8).asBytes());
this.hashing = Hashing.hmacSha256(sharedSecret.getBytes(UTF_8));
this.nodeId = nodeId;
this.jwtParser = newJwtParserBuilder().verifyWith(hmac).build();
this.currentToken = new AtomicReference<>(createJwt());
}

public static boolean isInternalRequest(ContainerRequestContext request)
{
return request.getHeaders().getFirst(TRINO_INTERNAL_BEARER) != null;
return request.getHeaders().getFirst(TRINO_INTERNAL_SIGNATURE) != null;
}

public void handleInternalRequest(ContainerRequestContext request)
{
String subject;
try {
subject = parseJwt(request.getHeaders().getFirst(TRINO_INTERNAL_BEARER));
}
catch (JwtException e) {
log.error(e, "Internal authentication failed");
String nodeId = getRequiredHeader(request, TRINO_INTERNAL_NODE_ID);
String signature = getRequiredHeader(request, TRINO_INTERNAL_SIGNATURE);

long requestTimestampMillis = parseLong(getRequiredHeader(request, TRINO_INTERNAL_REQUEST_TIMESTAMP));

if (!signature.equals(signature(nodeId, requestTimestampMillis, request.getUriInfo().getRequestUri()))) {
log.error("Internal authentication failed: request signature mismatch");
request.abortWith(Response.status(UNAUTHORIZED)
.type(TEXT_PLAIN_TYPE.toString())
.build());
return;
}
catch (RuntimeException e) {
throw new RuntimeException("Authentication error", e);

if (requestTimestampMillis < 0 || requestTimestampMillis + MAX_REQUEST_AGE_SECONDS * 1000 < currentTimeMillis()) {
log.error("Internal authentication failed: request expired at %s", requestTimestampMillis);
request.abortWith(Response.status(UNAUTHORIZED)
.type(TEXT_PLAIN_TYPE.toString())
.build());
return;
}

Identity identity = Identity.forUser("<internal>")
.withPrincipal(new InternalPrincipal(subject))
.withPrincipal(new InternalPrincipal(nodeId))
.build();
setAuthenticatedIdentity(request, identity);
}

private String signature(String nodeId, long requestTimestampMillis, URI uri)
{
return Base64.getEncoder()
.encodeToString(hashing.newHasher()
.putUnencodedChars(nodeId)
.putUnencodedChars(uri.toString())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the full URI is maintained perfectly across the HTTP protocol. I'd think about using just the path and maybe the query string.

.putLong(requestTimestampMillis)
.hash()
.asBytes());
}

@Override
public Request filterRequest(Request request)
{
long now = currentTimeMillis();
return fromRequest(request)
.addHeader(TRINO_INTERNAL_BEARER, getOrGenerateJwt())
.addHeader(TRINO_INTERNAL_NODE_ID, nodeId)
.addHeader(TRINO_INTERNAL_REQUEST_TIMESTAMP, Long.toString(now))
.addHeader(TRINO_INTERNAL_SIGNATURE, signature(nodeId, now, request.getUri()))
.build();
}

private String getOrGenerateJwt()
private static String getRequiredHeader(ContainerRequestContext request, String headerName)
{
InternalToken token = currentToken.get();
if (token.isExpired()) {
InternalToken newToken = createJwt();
if (currentToken.compareAndSet(token, newToken)) {
token = newToken;
}
else {
// Another thread already generated a new token
token = currentToken.get();
}
}
return token.token();
}

private InternalToken createJwt()
{
Instant expiration = DEFAULT_EXPIRATION_SUPPLIER.get();
return new InternalToken(expiration, newJwtBuilder()
.signWith(hmac)
.subject(nodeId)
.expiration(Date.from(expiration))
.compact());
}

private String parseJwt(String jwt)
{
return jwtParser
.parseSignedClaims(jwt)
.getPayload()
.getSubject();
}

private record InternalToken(Instant expiration, String token)
{
public InternalToken
{
expiration = TOKEN_REUSE_THRESHOLD.apply(requireNonNull(expiration, "expiration is null"));
requireNonNull(token, "token is null");
}

public boolean isExpired()
{
return Instant.now().isAfter(expiration);
String headerValue = request.getHeaderString(headerName);
if (headerValue == null || headerValue.isEmpty()) {
throw new ForbiddenException("Missing required authentication header: " + headerName);
}
return headerValue;
}
}
Loading