Skip to content

Commit f82a2a1

Browse files
fix: Change encoding from HMAC to AES when setting secret key (#26487)
1 parent ee1a9f7 commit f82a2a1

File tree

4 files changed

+137
-13
lines changed

4 files changed

+137
-13
lines changed

presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import com.nimbusds.jose.EncryptionMethod;
1818
import com.nimbusds.jose.JOSEException;
1919
import com.nimbusds.jose.JWEAlgorithm;
20+
import com.nimbusds.jose.JWEDecrypter;
21+
import com.nimbusds.jose.JWEEncrypter;
2022
import com.nimbusds.jose.JWEHeader;
2123
import com.nimbusds.jose.JWEObject;
2224
import com.nimbusds.jose.KeyLengthException;
@@ -49,8 +51,8 @@
4951
public class JweTokenSerializer
5052
implements TokenPairSerializer
5153
{
52-
private static final JWEAlgorithm ALGORITHM = JWEAlgorithm.A256KW;
53-
private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.A256CBC_HS512;
54+
private final JWEHeader encryptionHeader;
55+
5456
private static final CompressionCodec COMPRESSION_CODEC = new ZstdCodec();
5557
private static final String ACCESS_TOKEN_KEY = "access_token";
5658
private static final String EXPIRATION_TIME_KEY = "expiration_time";
@@ -61,8 +63,8 @@ public class JweTokenSerializer
6163
private final String audience;
6264
private final Duration tokenExpiration;
6365
private final JwtParser parser;
64-
private final AESEncrypter jweEncrypter;
65-
private final AESDecrypter jweDecrypter;
66+
private final JWEEncrypter jweEncrypter;
67+
private final JWEDecrypter jweDecrypter;
6668
private final String principalField;
6769

6870
public JweTokenSerializer(
@@ -84,6 +86,7 @@ public JweTokenSerializer(
8486
this.audience = requireNonNull(audience, "issuer is null");
8587
this.clock = requireNonNull(clock, "clock is null");
8688
this.tokenExpiration = requireNonNull(tokenExpiration, "tokenExpiration is null");
89+
this.encryptionHeader = createEncryptionHeader(secretKey);
8790

8891
this.parser = newJwtParserBuilder()
8992
.setClock(() -> Date.from(clock.instant()))
@@ -93,11 +96,26 @@ public JweTokenSerializer(
9396
.build();
9497
}
9598

99+
private JWEHeader createEncryptionHeader(SecretKey key)
100+
{
101+
int keyLength = key.getEncoded().length;
102+
switch (keyLength) {
103+
case 16:
104+
return new JWEHeader(JWEAlgorithm.A128GCMKW, EncryptionMethod.A128GCM);
105+
case 24:
106+
return new JWEHeader(JWEAlgorithm.A192GCMKW, EncryptionMethod.A192GCM);
107+
case 32:
108+
return new JWEHeader(JWEAlgorithm.A256GCMKW, EncryptionMethod.A256GCM);
109+
default:
110+
throw new IllegalArgumentException(
111+
String.format("Secret key size must be either 16, 24 or 32 bytes but was %d", keyLength));
112+
}
113+
}
114+
96115
@Override
97116
public TokenPair deserialize(String token)
98117
{
99118
requireNonNull(token, "token is null");
100-
101119
try {
102120
JWEObject jwe = JWEObject.parse(token);
103121
jwe.decrypt(jweDecrypter);
@@ -139,9 +157,7 @@ public String serialize(TokenPair tokenPair)
139157
.compressWith(COMPRESSION_CODEC);
140158

141159
try {
142-
JWEObject jwe = new JWEObject(
143-
new JWEHeader(ALGORITHM, ENCRYPTION_METHOD),
144-
new Payload(jwt.compact()));
160+
JWEObject jwe = new JWEObject(encryptionHeader, new Payload(jwt.compact()));
145161
jwe.encrypt(jweEncrypter);
146162
return jwe.serialize();
147163
}

presto-main/src/main/java/com/facebook/presto/server/security/oauth2/RefreshTokensConfig.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
import com.facebook.airlift.configuration.ConfigSecuritySensitive;
1919
import com.facebook.airlift.units.Duration;
2020
import io.jsonwebtoken.io.Decoders;
21-
import io.jsonwebtoken.security.Keys;
2221
import jakarta.validation.constraints.NotEmpty;
2322

2423
import javax.crypto.SecretKey;
24+
import javax.crypto.spec.SecretKeySpec;
2525

2626
import static com.google.common.base.Strings.isNullOrEmpty;
2727
import static java.util.concurrent.TimeUnit.HOURS;
@@ -85,7 +85,7 @@ public RefreshTokensConfig setSecretKey(String key)
8585
return this;
8686
}
8787

88-
secretKey = Keys.hmacShaKeyFor(Decoders.BASE64.decode(key));
88+
secretKey = new SecretKeySpec(Decoders.BASE64.decode(key), "AES");
8989
return this;
9090
}
9191

presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,10 @@ public Optional<String> getRefreshToken()
8787
{
8888
return refreshToken;
8989
}
90+
91+
public static TokenPair withAccessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken)
92+
{
93+
return new TokenPair(accessToken, expiration, Optional.ofNullable(refreshToken));
94+
}
9095
}
9196
}

presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestJweTokenSerializer.java

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,41 @@
1818
import com.nimbusds.jose.KeyLengthException;
1919
import io.jsonwebtoken.ExpiredJwtException;
2020
import io.jsonwebtoken.Jwts;
21+
import org.testng.annotations.DataProvider;
2122
import org.testng.annotations.Test;
2223

2324
import java.net.URI;
2425
import java.security.GeneralSecurityException;
26+
import java.security.NoSuchAlgorithmException;
27+
import java.security.SecureRandom;
2528
import java.time.Clock;
2629
import java.time.Instant;
2730
import java.time.ZoneId;
2831
import java.time.ZonedDateTime;
32+
import java.util.Base64;
2933
import java.util.Calendar;
3034
import java.util.Date;
3135
import java.util.Map;
3236
import java.util.Optional;
37+
import java.util.Random;
3338

3439
import static com.facebook.airlift.units.Duration.succinctDuration;
3540
import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.accessAndRefreshTokens;
41+
import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.withAccessAndRefreshTokens;
3642
import static java.time.temporal.ChronoUnit.MILLIS;
3743
import static java.util.concurrent.TimeUnit.MINUTES;
3844
import static java.util.concurrent.TimeUnit.SECONDS;
3945
import static org.assertj.core.api.Assertions.assertThat;
4046
import static org.assertj.core.api.Assertions.assertThatThrownBy;
47+
import static org.testng.Assert.assertEquals;
4148

4249
public class TestJweTokenSerializer
4350
{
4451
@Test
4552
public void testSerialization()
4653
throws Exception
4754
{
48-
JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS));
55+
JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), randomEncodedSecret());
4956

5057
Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime();
5158
String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token"));
@@ -56,14 +63,75 @@ public void testSerialization()
5663
assertThat(deserializedTokenPair.getRefreshToken()).isEqualTo(Optional.of("refresh_token"));
5764
}
5865

66+
@Test(dataProvider = "wrongSecretsProvider")
67+
public void testDeserializationWithWrongSecret(String encryptionSecret, String decryptionSecret)
68+
{
69+
assertThatThrownBy(() -> assertRoundTrip(Optional.ofNullable(encryptionSecret), Optional.ofNullable(decryptionSecret)))
70+
.isInstanceOf(IllegalArgumentException.class)
71+
.hasMessage("Decryption failed")
72+
.hasStackTraceContaining("Tag mismatch!");
73+
}
74+
75+
@DataProvider
76+
public Object[][] wrongSecretsProvider()
77+
{
78+
return new Object[][] {
79+
{randomEncodedSecret(), randomEncodedSecret()},
80+
{randomEncodedSecret(16), randomEncodedSecret(24)},
81+
{null, null}, // This will generate two different secret keys
82+
{null, randomEncodedSecret()},
83+
{randomEncodedSecret(), null}
84+
};
85+
}
86+
87+
@Test
88+
public void testSerializationDeserializationRoundTripWithDifferentKeyLengths()
89+
throws Exception
90+
{
91+
for (int keySize : new int[] {16, 24, 32}) {
92+
String secret = randomEncodedSecret(keySize);
93+
assertRoundTrip(secret, secret);
94+
}
95+
}
96+
97+
@Test
98+
public void testSerializationFailsWithWrongKeySize()
99+
{
100+
for (int wrongKeySize : new int[] {8, 64, 128}) {
101+
String tooShortSecret = randomEncodedSecret(wrongKeySize);
102+
assertThatThrownBy(() -> assertRoundTrip(tooShortSecret, tooShortSecret))
103+
.hasStackTraceContaining("The Key Encryption Key length must be 128 bits (16 bytes), 192 bits (24 bytes) or 256 bits (32 bytes)");
104+
}
105+
}
106+
107+
private void assertRoundTrip(String serializerSecret, String deserializerSecret)
108+
throws Exception
109+
{
110+
assertRoundTrip(Optional.of(serializerSecret), Optional.of(deserializerSecret));
111+
}
112+
113+
private void assertRoundTrip(Optional<String> serializerSecret, Optional<String> deserializerSecret)
114+
throws Exception
115+
{
116+
JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), serializerSecret);
117+
JweTokenSerializer deserializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), deserializerSecret);
118+
Date expiration = new Calendar.Builder().setDate(2023, 6, 22).build().getTime();
119+
TokenPair tokenPair = withAccessAndRefreshTokens(randomEncodedSecret(), expiration, randomEncodedSecret());
120+
TokenPair postSerPair = deserializer.deserialize(serializer.serialize(tokenPair));
121+
assertEquals(tokenPair.getAccessToken(), postSerPair.getAccessToken());
122+
assertEquals(tokenPair.getRefreshToken(), postSerPair.getRefreshToken());
123+
assertEquals(tokenPair.getExpiration(), postSerPair.getExpiration());
124+
}
125+
59126
@Test
60127
public void testTokenDeserializationAfterTimeoutButBeforeExpirationExtension()
61128
throws Exception
62129
{
63130
TestingClock clock = new TestingClock();
64131
JweTokenSerializer serializer = tokenSerializer(
65132
clock,
66-
succinctDuration(12, MINUTES));
133+
succinctDuration(12, MINUTES),
134+
randomEncodedSecret());
67135
Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime();
68136
String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token"));
69137
clock.advanceBy(succinctDuration(10, MINUTES));
@@ -82,7 +150,8 @@ public void testTokenDeserializationAfterTimeoutAndExpirationExtension()
82150

83151
JweTokenSerializer serializer = tokenSerializer(
84152
clock,
85-
succinctDuration(12, MINUTES));
153+
succinctDuration(12, MINUTES),
154+
randomEncodedSecret());
86155
Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime();
87156
String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token"));
88157

@@ -104,6 +173,40 @@ private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration
104173
tokenExpiration);
105174
}
106175

176+
private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, String encodedSecretKey)
177+
throws GeneralSecurityException, KeyLengthException
178+
{
179+
return tokenSerializer(clock, tokenExpiration, Optional.of(encodedSecretKey));
180+
}
181+
182+
private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, Optional<String> secretKey)
183+
throws NoSuchAlgorithmException, KeyLengthException
184+
{
185+
RefreshTokensConfig refreshTokensConfig = new RefreshTokensConfig();
186+
secretKey.ifPresent(refreshTokensConfig::setSecretKey);
187+
return new JweTokenSerializer(
188+
refreshTokensConfig,
189+
new Oauth2ClientStub(),
190+
"presto_coordinator_test_version",
191+
"presto_coordinator",
192+
"sub",
193+
clock,
194+
tokenExpiration);
195+
}
196+
197+
private static String randomEncodedSecret()
198+
{
199+
return randomEncodedSecret(24);
200+
}
201+
202+
private static String randomEncodedSecret(int length)
203+
{
204+
Random random = new SecureRandom();
205+
final byte[] buffer = new byte[length];
206+
random.nextBytes(buffer);
207+
return Base64.getEncoder().encodeToString(buffer);
208+
}
209+
107210
static class Oauth2ClientStub
108211
implements OAuth2Client
109212
{

0 commit comments

Comments
 (0)