1818import com .nimbusds .jose .KeyLengthException ;
1919import io .jsonwebtoken .ExpiredJwtException ;
2020import io .jsonwebtoken .Jwts ;
21+ import org .testng .annotations .DataProvider ;
2122import org .testng .annotations .Test ;
2223
2324import java .net .URI ;
2425import java .security .GeneralSecurityException ;
26+ import java .security .NoSuchAlgorithmException ;
27+ import java .security .SecureRandom ;
2528import java .time .Clock ;
2629import java .time .Instant ;
2730import java .time .ZoneId ;
2831import java .time .ZonedDateTime ;
32+ import java .util .Base64 ;
2933import java .util .Calendar ;
3034import java .util .Date ;
3135import java .util .Map ;
3236import java .util .Optional ;
37+ import java .util .Random ;
3338
3439import static com .facebook .airlift .units .Duration .succinctDuration ;
3540import static com .facebook .presto .server .security .oauth2 .TokenPairSerializer .TokenPair .accessAndRefreshTokens ;
41+ import static com .facebook .presto .server .security .oauth2 .TokenPairSerializer .TokenPair .withAccessAndRefreshTokens ;
3642import static java .time .temporal .ChronoUnit .MILLIS ;
3743import static java .util .concurrent .TimeUnit .MINUTES ;
3844import static java .util .concurrent .TimeUnit .SECONDS ;
3945import static org .assertj .core .api .Assertions .assertThat ;
4046import static org .assertj .core .api .Assertions .assertThatThrownBy ;
47+ import static org .testng .Assert .assertEquals ;
4148
4249public class TestJweTokenSerializer
4350{
4451 @ Test
4552 public void testSerialization ()
4653 throws Exception
4754 {
48- JweTokenSerializer serializer = tokenSerializer (Clock .systemUTC (), succinctDuration (5 , SECONDS ));
55+ JweTokenSerializer serializer = tokenSerializer (Clock .systemUTC (), succinctDuration (5 , SECONDS ), randomEncodedSecret () );
4956
5057 Date expiration = new Calendar .Builder ().setDate (2022 , 6 , 22 ).build ().getTime ();
5158 String serializedTokenPair = serializer .serialize (accessAndRefreshTokens ("access_token" , expiration , "refresh_token" ));
@@ -56,14 +63,75 @@ public void testSerialization()
5663 assertThat (deserializedTokenPair .getRefreshToken ()).isEqualTo (Optional .of ("refresh_token" ));
5764 }
5865
66+ @ Test (dataProvider = "wrongSecretsProvider" )
67+ public void testDeserializationWithWrongSecret (String encryptionSecret , String decryptionSecret )
68+ {
69+ assertThatThrownBy (() -> assertRoundTrip (Optional .ofNullable (encryptionSecret ), Optional .ofNullable (decryptionSecret )))
70+ .isInstanceOf (IllegalArgumentException .class )
71+ .hasMessage ("Decryption failed" )
72+ .hasStackTraceContaining ("Tag mismatch!" );
73+ }
74+
75+ @ DataProvider
76+ public Object [][] wrongSecretsProvider ()
77+ {
78+ return new Object [][] {
79+ {randomEncodedSecret (), randomEncodedSecret ()},
80+ {randomEncodedSecret (16 ), randomEncodedSecret (24 )},
81+ {null , null }, // This will generate two different secret keys
82+ {null , randomEncodedSecret ()},
83+ {randomEncodedSecret (), null }
84+ };
85+ }
86+
87+ @ Test
88+ public void testSerializationDeserializationRoundTripWithDifferentKeyLengths ()
89+ throws Exception
90+ {
91+ for (int keySize : new int [] {16 , 24 , 32 }) {
92+ String secret = randomEncodedSecret (keySize );
93+ assertRoundTrip (secret , secret );
94+ }
95+ }
96+
97+ @ Test
98+ public void testSerializationFailsWithWrongKeySize ()
99+ {
100+ for (int wrongKeySize : new int [] {8 , 64 , 128 }) {
101+ String tooShortSecret = randomEncodedSecret (wrongKeySize );
102+ assertThatThrownBy (() -> assertRoundTrip (tooShortSecret , tooShortSecret ))
103+ .hasStackTraceContaining ("The Key Encryption Key length must be 128 bits (16 bytes), 192 bits (24 bytes) or 256 bits (32 bytes)" );
104+ }
105+ }
106+
107+ private void assertRoundTrip (String serializerSecret , String deserializerSecret )
108+ throws Exception
109+ {
110+ assertRoundTrip (Optional .of (serializerSecret ), Optional .of (deserializerSecret ));
111+ }
112+
113+ private void assertRoundTrip (Optional <String > serializerSecret , Optional <String > deserializerSecret )
114+ throws Exception
115+ {
116+ JweTokenSerializer serializer = tokenSerializer (Clock .systemUTC (), succinctDuration (5 , SECONDS ), serializerSecret );
117+ JweTokenSerializer deserializer = tokenSerializer (Clock .systemUTC (), succinctDuration (5 , SECONDS ), deserializerSecret );
118+ Date expiration = new Calendar .Builder ().setDate (2023 , 6 , 22 ).build ().getTime ();
119+ TokenPair tokenPair = withAccessAndRefreshTokens (randomEncodedSecret (), expiration , randomEncodedSecret ());
120+ TokenPair postSerPair = deserializer .deserialize (serializer .serialize (tokenPair ));
121+ assertEquals (tokenPair .getAccessToken (), postSerPair .getAccessToken ());
122+ assertEquals (tokenPair .getRefreshToken (), postSerPair .getRefreshToken ());
123+ assertEquals (tokenPair .getExpiration (), postSerPair .getExpiration ());
124+ }
125+
59126 @ Test
60127 public void testTokenDeserializationAfterTimeoutButBeforeExpirationExtension ()
61128 throws Exception
62129 {
63130 TestingClock clock = new TestingClock ();
64131 JweTokenSerializer serializer = tokenSerializer (
65132 clock ,
66- succinctDuration (12 , MINUTES ));
133+ succinctDuration (12 , MINUTES ),
134+ randomEncodedSecret ());
67135 Date expiration = new Calendar .Builder ().setDate (2022 , 6 , 22 ).build ().getTime ();
68136 String serializedTokenPair = serializer .serialize (accessAndRefreshTokens ("access_token" , expiration , "refresh_token" ));
69137 clock .advanceBy (succinctDuration (10 , MINUTES ));
@@ -82,7 +150,8 @@ public void testTokenDeserializationAfterTimeoutAndExpirationExtension()
82150
83151 JweTokenSerializer serializer = tokenSerializer (
84152 clock ,
85- succinctDuration (12 , MINUTES ));
153+ succinctDuration (12 , MINUTES ),
154+ randomEncodedSecret ());
86155 Date expiration = new Calendar .Builder ().setDate (2022 , 6 , 22 ).build ().getTime ();
87156 String serializedTokenPair = serializer .serialize (accessAndRefreshTokens ("access_token" , expiration , "refresh_token" ));
88157
@@ -104,6 +173,40 @@ private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration
104173 tokenExpiration );
105174 }
106175
176+ private JweTokenSerializer tokenSerializer (Clock clock , Duration tokenExpiration , String encodedSecretKey )
177+ throws GeneralSecurityException , KeyLengthException
178+ {
179+ return tokenSerializer (clock , tokenExpiration , Optional .of (encodedSecretKey ));
180+ }
181+
182+ private JweTokenSerializer tokenSerializer (Clock clock , Duration tokenExpiration , Optional <String > secretKey )
183+ throws NoSuchAlgorithmException , KeyLengthException
184+ {
185+ RefreshTokensConfig refreshTokensConfig = new RefreshTokensConfig ();
186+ secretKey .ifPresent (refreshTokensConfig ::setSecretKey );
187+ return new JweTokenSerializer (
188+ refreshTokensConfig ,
189+ new Oauth2ClientStub (),
190+ "presto_coordinator_test_version" ,
191+ "presto_coordinator" ,
192+ "sub" ,
193+ clock ,
194+ tokenExpiration );
195+ }
196+
197+ private static String randomEncodedSecret ()
198+ {
199+ return randomEncodedSecret (24 );
200+ }
201+
202+ private static String randomEncodedSecret (int length )
203+ {
204+ Random random = new SecureRandom ();
205+ final byte [] buffer = new byte [length ];
206+ random .nextBytes (buffer );
207+ return Base64 .getEncoder ().encodeToString (buffer );
208+ }
209+
107210 static class Oauth2ClientStub
108211 implements OAuth2Client
109212 {
0 commit comments