diff --git a/README.md b/README.md index 879cd46..4376bc5 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,11 @@ -# TLS Proxy Server in Scala +# TLS (HTTPS) Proxy Server in Scala -Very simple HTTPS proxy server written in Scala 2.12. +Very simple HTTPS proxy server written in Scala 2.12 with no dependencies +beyond `scala-logging` -Can be used as a library with no dependencies beyond `scala-logging`, or as a -standalone program. +Can be used as a library, or as a standalone program. + +## Standalone ``` $ sbt run @@ -13,4 +15,70 @@ $ sbt run [info] Set current project to tlsproxy (in build file:/Users/erik/work/tlsproxy/) [info] Running tlsproxy.Main 14:31:46.707 [run-main-0] INFO tlsproxy.ServerHandler - Listening on port 3128... -``` \ No newline at end of file +18:03:15.466 [main] INFO tlsproxy.ServerHandler - Listening on port 3128... +18:03:22.651 [main] ERROR tlsproxy.TlsProxyHandler - /0:0:0:0:0:0:0:1:49672 -> google.com:443: error: connection closed: java.io.IOException: Connection reset by peer +18:04:22.806 [main] INFO tlsproxy.TlsProxyHandler - /0:0:0:0:0:0:0:1:49818 -> www.google.com/172.217.6.36:443 finished (up: 581, down: 4294) +18:04:56.131 [main] INFO tlsproxy.TlsProxyHandler - /0:0:0:0:0:0:0:1:49807 -> nginx.org/52.58.199.22:443 finished (up: 568, down: 187) +``` + +Now configure `localhost:3128` as proxy in your browser. + +``` +$ curl -I -x localhost:3128 https://woefdram.nl +HTTP/1.1 200 Connection Accepted +Proxy-Agent: TlsProxy/1.0 (github.com/erikvanzijst/scala_tlsproxy) +Content-Type: text/plain; charset=us-ascii +Content-Length: 0 + +HTTP/2 200 +server: nginx/1.18.0 +date: Tue, 17 Aug 2021 16:19:04 GMT +content-type: text/html +content-length: 612 +last-modified: Tue, 21 Apr 2020 14:09:01 GMT +etag: "5e9efe7d-264" +accept-ranges: bytes +``` + +## Library + +To use it as a library in-process: + +```scala +import tlsproxy.TlsProxy + +new TlsProxy(3128).run() +``` + +The `run()` does not create any threads and run the entire proxy on the +calling thread. It does not return. + +To move it to the background, pass it to a `Thread` or `Executor`: + +```scala +import tlsproxy.TlsProxy +import java.util.concurrent.Executors + +val executor = Executors.newSingleThreadExecutor() +executor.submit(new TlsProxy(3128)) +``` + +## Caveat emptor + +This is only implements the `CONNECT` method and can therefor only proxy HTTPS +requests. It does not support unencrypted proxy requests using `GET`. + +Proxy requests for HTTP (non-TLS) `GET` requests result in an error and the +connection getting closed: + +``` +18:08:53.604 [main] ERROR tlsproxy.TlsProxyHandler - /0:0:0:0:0:0:0:1:51043 -> unconnected: error: connection closed: java.io.IOException: Malformed request +``` + +## Robustness (or lack thereof) + +* This implementation is totally susceptible to all kinds of [slowloris attacks](https://en.wikipedia.org/wiki/Slowloris_(computer_security). +* It does not support client authentication +* Uses only 1 thread and cannot currently scale to multiple cores +* Does not restrict non-standard upstream ports +* Undoubtedly riddled with bugs diff --git a/src/main/scala/tlsproxy/ClientHandler.scala b/src/main/scala/tlsproxy/EchoHandler.scala similarity index 92% rename from src/main/scala/tlsproxy/ClientHandler.scala rename to src/main/scala/tlsproxy/EchoHandler.scala index 7c4cb06..8453938 100644 --- a/src/main/scala/tlsproxy/ClientHandler.scala +++ b/src/main/scala/tlsproxy/EchoHandler.scala @@ -8,7 +8,7 @@ import com.typesafe.scalalogging.StrictLogging import scala.collection.JavaConverters._ -class ClientHandler(selector: Selector, socketChannel: SocketChannel) extends KeyHandler with StrictLogging { +class EchoHandler(selector: Selector, socketChannel: SocketChannel) extends KeyHandler with StrictLogging { socketChannel.configureBlocking(false) private val peer = socketChannel.getRemoteAddress private val buffer = ByteBuffer.allocate((1 << 16) - 1) @@ -54,10 +54,11 @@ class ClientHandler(selector: Selector, socketChannel: SocketChannel) extends Ke } def close(): Unit = + shutdown = true if (selectionKey.isValid) { selectionKey.cancel() socketChannel.close() logger.info("{} connection closed (total connected clients: {})", - peer, selector.keys().asScala.count(_.attachment().isInstanceOf[ClientHandler]) - 1) + peer, selector.keys().asScala.count(_.attachment().isInstanceOf[EchoHandler]) - 1) } } diff --git a/src/main/scala/tlsproxy/Main.scala b/src/main/scala/tlsproxy/Main.scala index 3dcbb4b..f40917c 100644 --- a/src/main/scala/tlsproxy/Main.scala +++ b/src/main/scala/tlsproxy/Main.scala @@ -2,10 +2,19 @@ package tlsproxy import java.util.concurrent.{ExecutorService, Executors} +import ch.qos.logback.classic.Level import com.typesafe.scalalogging.StrictLogging +import org.slf4j.LoggerFactory import scopt.OptionParser object Main extends StrictLogging { + + // Suppress debug when running from the shell + Seq("tlsproxy.TlsProxyHandler", "tlsproxy.ServerHandler", "tlsproxy.Pipe") + .map(LoggerFactory.getLogger) + .map(_.asInstanceOf[ch.qos.logback.classic.Logger]) + .foreach(_.setLevel(Level.INFO)) + def main(args: Array[String]): Unit = { case class Config(port: Int = 3128) diff --git a/src/main/scala/tlsproxy/Pipe.scala b/src/main/scala/tlsproxy/Pipe.scala new file mode 100644 index 0000000..de1ddeb --- /dev/null +++ b/src/main/scala/tlsproxy/Pipe.scala @@ -0,0 +1,50 @@ +package tlsproxy + +import java.nio.ByteBuffer +import java.nio.channels.{SelectionKey, SocketChannel} + +import com.typesafe.scalalogging.StrictLogging + +class Pipe(fromKey: SelectionKey, fromChannel: SocketChannel, toKey: SelectionKey, toChannel: SocketChannel) + extends KeyHandler with StrictLogging { + + private val buffer = ByteBuffer.allocate(1 << 16) + private var shutdown = false + private var count: Long = 0 + + def bytes: Long = count + + def isClosed: Boolean = shutdown && buffer.position() == 0 + + override def process(): Unit = { + + if (fromKey.isValid && fromKey.isReadable && !shutdown) { + val len = fromChannel.read(buffer) + if (len == -1) { + logger.debug("{} -> {} EOF reached", fromChannel.getRemoteAddress, toChannel.getRemoteAddress) + shutdown = true + } else { + count = count + len + } + } + + if (toKey.isValid && toKey.isWritable) { + buffer.flip() + toChannel.write(buffer) + buffer.compact() + } + + if (shutdown && buffer.position() == 0) toChannel.shutdownOutput() + + if (toKey.isValid) { + toKey.interestOps( + if (buffer.position() > 0) toKey.interestOps() | SelectionKey.OP_WRITE + else toKey.interestOps() & ~SelectionKey.OP_WRITE) + } + if (fromKey.isValid) { + fromKey.interestOps( + if (!buffer.hasRemaining || shutdown) fromKey.interestOps() & ~SelectionKey.OP_READ + else fromKey.interestOps() | SelectionKey.OP_READ) + } + } +} diff --git a/src/main/scala/tlsproxy/ServerHandler.scala b/src/main/scala/tlsproxy/ServerHandler.scala index ee8d940..688a53d 100644 --- a/src/main/scala/tlsproxy/ServerHandler.scala +++ b/src/main/scala/tlsproxy/ServerHandler.scala @@ -17,8 +17,8 @@ class ServerHandler(selector: Selector, port: Int) extends KeyHandler with Stric override def process(): Unit = { val channel = serverSocketChannel.accept() - new ClientHandler(selector, channel) - logger.info("New incoming connection from {} (total connected clients: {})", - channel.getRemoteAddress, selector.keys().asScala.count(_.attachment().isInstanceOf[ClientHandler])) + new TlsProxyHandler(selector, channel) + logger.debug("New incoming connection from {} (total connected clients: {})", + channel.getRemoteAddress, selector.keys().asScala.count(_.attachment().isInstanceOf[TlsProxyHandler])) } } diff --git a/src/main/scala/tlsproxy/TlsProxy.scala b/src/main/scala/tlsproxy/TlsProxy.scala index a8c9254..441239f 100644 --- a/src/main/scala/tlsproxy/TlsProxy.scala +++ b/src/main/scala/tlsproxy/TlsProxy.scala @@ -2,7 +2,9 @@ package tlsproxy import java.nio.channels.Selector +import ch.qos.logback.classic.Level import com.typesafe.scalalogging.StrictLogging +import org.slf4j.LoggerFactory trait KeyHandler { def process(): Unit @@ -18,7 +20,7 @@ class TlsProxy(port: Int) extends StrictLogging with Runnable with AutoCloseable override def run(): Unit = { val selector = Selector.open - val server = new ServerHandler(selector, port) + new ServerHandler(selector, port) while (true) { if (selector.select(5000) > 0) { diff --git a/src/main/scala/tlsproxy/TlsProxyHandler.scala b/src/main/scala/tlsproxy/TlsProxyHandler.scala new file mode 100644 index 0000000..293dfba --- /dev/null +++ b/src/main/scala/tlsproxy/TlsProxyHandler.scala @@ -0,0 +1,197 @@ +package tlsproxy + +import java.io.IOException +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.{SelectionKey, Selector, SocketChannel, UnresolvedAddressException} +import java.nio.charset.StandardCharsets + +import com.typesafe.scalalogging.StrictLogging +import tlsproxy.TlsProxyHandler.userAgent + +import scala.collection.JavaConverters._ +import scala.util.Try + +object ProxyPhase extends Enumeration { + type ProxyPhase = Value + val Destination, Headers, Response, Connecting, Established, Error = Value +} + +object TlsProxyHandler { + private val destPattern = "CONNECT ([^:]+):([0-9]+) HTTP/1.1".r + private val userAgent = "TlsProxy/1.0 (github.com/erikvanzijst/scala_tlsproxy)" +} + +class TlsProxyHandler(selector: Selector, clientChannel: SocketChannel) extends KeyHandler with StrictLogging { + import ProxyPhase._ + + clientChannel.configureBlocking(false) + private val peer = clientChannel.getRemoteAddress + + private val clientKey = clientChannel.register(selector, SelectionKey.OP_READ, this) // client initiating the connection + private val clientBuffer = ByteBuffer.allocate(1 << 16) // client-to-server + + private var serverKey: SelectionKey = _ // the upstream server + private var serverChannel: SocketChannel = _ + private val serverBuffer = ByteBuffer.allocate(1 << 16) // server-to-client + + private var upstreamPipe: Pipe = _ + private var downstreamPipe: Pipe = _ + + private var shutdown = false + private var destination: (String, Int) = _ + + private var phase = Destination + + private def readClient(): Unit = { + if (clientKey.isValid && clientKey.isReadable && clientChannel.read(clientBuffer) == -1) + throw new IOException(s"$peer unexpected EOF from client") + if (!clientBuffer.hasRemaining) + throw new IOException(s"$peer handshake overflow") + } + + + private def readLine(): Option[String] = { + readClient() + clientBuffer.flip() + + val s = StandardCharsets.US_ASCII.decode(clientBuffer).toString + s.indexOf("\r\n") match { + case eol if eol != -1 => + clientBuffer.position(eol + 2) + clientBuffer.compact() + Some(s.substring(0, eol)) + case _ => + clientBuffer.position(0) + clientBuffer.compact() + None + } + } + + private def startResponse(statusCode: Int, statusLine: String, body: String): Unit = { + val resp = response(statusCode, statusLine, body) + serverBuffer.put(resp, 0, resp.length) + clientKey.interestOps(SelectionKey.OP_WRITE) + } + + private def response(statusCode: Int, statusLine: String, body: String): Array[Byte] = + (s"HTTP/1.1 $statusCode $statusLine\r\n" + + s"Proxy-Agent: $userAgent\r\n" + + "Content-Type: text/plain; charset=us-ascii\r\n" + + s"Content-Length: ${body.length}\r\n" + + "\r\n" + + body).getBytes(StandardCharsets.US_ASCII) + + override def process(): Unit = { + try { + + if (phase == Destination) + readLine().map(TlsProxyHandler.destPattern.findFirstMatchIn(_)).foreach { + case Some(m) => + destination = (m.group(1), m.group(2).toInt) + logger.debug("{} wants to connect to {}:{}...", peer, destination._1, destination._2) + phase = Headers + case _ => throw new IOException(s"Malformed request") + } + + if (phase == Headers) + Iterator.continually(readLine()).takeWhile(_.isDefined).flatten.foreach { + case header if header == "" => + logger.debug("{} all headers consumed, initiating upstream connection...", peer) + + serverChannel = SocketChannel.open() + serverChannel.configureBlocking(false) + serverKey = serverChannel.register(selector, SelectionKey.OP_CONNECT, this) + clientKey.interestOps(0) // stop reading while we connect upstream or server a response + + phase = Try { + if (serverChannel.connect(new InetSocketAddress(destination._1, destination._2))) { + startResponse(200, "Connection Accepted", "") + Response + } else { + Connecting + } + }.recover { + case _: UnresolvedAddressException => + logger.info(s"{} cannot resolve {}", peer, destination._1) + startResponse(502, "Bad Gateway", s"Failed to resolve ${destination._1}") + Error + case iae: IllegalArgumentException => + startResponse(400, "Bad Request", s"${iae.getMessage}\n") + Error + }.get + + case header => logger.debug("{} ignoring header {}", peer, header) + } + + if (phase == Connecting) + if (serverKey.isConnectable) + phase = Try { + serverChannel.finishConnect() + startResponse(200, "Connection Accepted", "") + Response + }.recover { case ioe: IOException => + startResponse(502, "Gateway Error", s"${ioe.getMessage}\n") + Error + }.get + + if (phase == Response) + if (clientKey.isWritable) { + serverBuffer.flip() + clientChannel.write(serverBuffer) + serverBuffer.compact() + + if (serverBuffer.position() == 0) { + clientKey.interestOps(SelectionKey.OP_READ) + serverKey.interestOps(SelectionKey.OP_READ) + + upstreamPipe = new Pipe(clientKey, clientChannel, serverKey, serverChannel) + downstreamPipe = new Pipe(serverKey, serverChannel, clientKey, clientChannel) + + logger.debug("{} 200 OK sent to client -- TLS connection to {} ready", peer, serverChannel.getRemoteAddress) + phase = Established + } + } + + if (phase == Established) { + upstreamPipe.process() + downstreamPipe.process() + if (upstreamPipe.isClosed && downstreamPipe.isClosed) { + logger.info("{} -> {} finished (up: {}, down: {})", + peer, serverChannel.getRemoteAddress, upstreamPipe.bytes, downstreamPipe.bytes) + close() + } + } + + if (phase == Error) + if (clientKey.isWritable) { + serverBuffer.flip() + clientChannel.write(serverBuffer) + serverBuffer.compact() + + if (serverBuffer.position() == 0) { + close() + } + } + + } catch { + case e: IOException => + logger.error(s"$peer -> ${Option(destination).map(s => s"${s._1}:${s._2}").getOrElse("unconnected")}: error: connection closed: ${e.getClass.getName}: ${e.getMessage}") + close() + } + } + + def close(): Unit = { + shutdown = true + if (clientKey.isValid) { + clientKey.cancel() + clientChannel.close() + } + if (serverKey != null && serverKey.isValid) { + serverKey.cancel() + serverChannel.close() + } + logger.debug("{} connection closed (total connected clients: {})", + peer, selector.keys().asScala.count(_.attachment().isInstanceOf[TlsProxyHandler]) - 1) + } +}