diff --git a/druid/src/main/scala/com/yahoo/maha/executor/druid/HttpUtils.scala b/druid/src/main/scala/com/yahoo/maha/executor/druid/HttpUtils.scala index 5ffa6dcfd..a2c557d2c 100644 --- a/druid/src/main/scala/com/yahoo/maha/executor/druid/HttpUtils.scala +++ b/druid/src/main/scala/com/yahoo/maha/executor/druid/HttpUtils.scala @@ -6,14 +6,17 @@ package com.yahoo.maha.executor.druid * Created by vivekch on 3/2/16. */ -import java.io.Closeable -import javax.net.ssl.SSLContext +import java.io.{Closeable, FileInputStream} +import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory} import com.yahoo.maha.executor.druid.filters.TimeoutThrottlingFilter import grizzled.slf4j.Logging import io.netty.handler.ssl.{SslContext, SslContextBuilder} import org.apache.http.HttpHeaders import org.apache.http.entity.ContentType -import org.asynchttpclient.{AsyncHttpClient, AsyncHttpClientConfig, BoundRequestBuilder, DefaultAsyncHttpClient, DefaultAsyncHttpClientConfig, Response}; +import org.asynchttpclient.{AsyncHttpClient, AsyncHttpClientConfig, BoundRequestBuilder, DefaultAsyncHttpClient, DefaultAsyncHttpClientConfig, Response} + +import java.security.spec.PKCS8EncodedKeySpec +import java.security.{KeyFactory, KeyStore, PrivateKey}; class HttpUtils(config:AsyncHttpClientConfig, enableRetryOn500: Boolean, retryDelayMillis: Int, maxRetry: Int) extends Logging with Closeable { @@ -108,7 +111,64 @@ object ClientConfig{ , customizeBuilder: Option[(DefaultAsyncHttpClientConfig.Builder) => Unit] = None ): AsyncHttpClientConfig ={ val builder = new DefaultAsyncHttpClientConfig.Builder() - val sslContext: SslContext = SslContextBuilder.forClient().build() + //val sslContext: SslContext = SslContextBuilder.forClient().build() + // test with other mtls + val clientTrustStorePath = "" + val clientTrustStorePassword = "" + val clientPublicCertPath = "" + val clientPrivateKeyPath = "" + + def getCertificate(path: String): java.security.cert.Certificate = { + val inputStream = new FileInputStream(path) + val certificate = java.security.cert.CertificateFactory.getInstance("X.509").generateCertificate(inputStream) + inputStream.close() + certificate + } + + def getPrivateKey(path: String, keyFactory: KeyFactory): PrivateKey = { + val inputStream = new FileInputStream(path) + val privateKeyBytes = readPemFile(inputStream) + inputStream.close() + val privateKeySpec = new PKCS8EncodedKeySpec(privateKeyBytes) + keyFactory.generatePrivate(privateKeySpec) + } + + def readPemFile(inputStream: FileInputStream): Array[Byte] = { + val pemBytes = Stream.continually(inputStream.read).takeWhile(_ != -1).map(_.toByte).toArray + val header = "-----BEGIN PRIVATE KEY-----" + val footer = "-----END PRIVATE KEY-----" + val pemString = new String(pemBytes) + val pemWithoutHeaderAndFooter = pemString + .stripPrefix(header) + .stripSuffix(footer) + .replaceAll("\\s", "") + javax.xml.bind.DatatypeConverter.parseBase64Binary(pemWithoutHeaderAndFooter) + } + + // Load the trust store + val trustStore = KeyStore.getInstance("JKS") + trustStore.load(new FileInputStream(clientTrustStorePath), clientTrustStorePassword.toCharArray) + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) + trustManagerFactory.init(trustStore) + + // Load the client certificate and key + val keyFactory = KeyFactory.getInstance("RSA") + val privateKey = getPrivateKey(clientPrivateKeyPath, keyFactory) + val certificate = getCertificate(clientPublicCertPath) + + // Load the client certificate and key + val keyStore = KeyStore.getInstance("PKCS12") + keyStore.load(null) + keyStore.setCertificateEntry("cert", certificate) + keyStore.setKeyEntry("key", privateKey, null, Array(getCertificate(clientPublicCertPath))) + val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm) + keyManagerFactory.init(keyStore, null) + + val sslContext: SslContext = SslContextBuilder.forClient() + .trustManager(trustManagerFactory) + .keyManager(keyManagerFactory) + .build() + customizeBuilder.foreach(_(builder)) builder .setMaxConnectionsPerHost(maxConnectionsPerHost)