diff --git a/Cargo.lock b/Cargo.lock index fa59854..97425ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1873,6 +1873,53 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "opentelemetry" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab70038c28ed37b97d8ed414b6429d343a8bbf44c9f79ec854f3a643029ba6d7" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "pin-project-lite", + "thiserror", + "tracing", +] + +[[package]] +name = "opentelemetry-prometheus" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b834e966ea5e2d03dfe5f2253f03d22cce21403ee940265070eeee96cee0bcc" +dependencies = [ + "once_cell", + "opentelemetry", + "opentelemetry_sdk", + "prometheus", + "protobuf", + "tracing", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "231e9d6ceef9b0b2546ddf52335785ce41252bc7474ee8ba05bfad277be13ab8" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "opentelemetry", + "percent-encoding", + "rand", + "serde_json", + "thiserror", + "tracing", +] + [[package]] name = "overload" version = "0.1.1" @@ -2029,6 +2076,27 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prometheus" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d33c28a30771f7f96db69893f78b857f7450d7e0237e9c8fc6427a81bae7ed1" +dependencies = [ + "cfg-if", + "fnv", + "lazy_static", + "memchr", + "parking_lot", + "protobuf", + "thiserror", +] + +[[package]] +name = "protobuf" +version = "2.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" + [[package]] name = "quick-error" version = "1.2.3" @@ -3534,9 +3602,13 @@ dependencies = [ "log", "nix", "notify", + "opentelemetry", + "opentelemetry-prometheus", + "opentelemetry_sdk", "parking_lot", "pin-project", "ppp", + "prometheus", "rcgen", "regex", "rstest", @@ -3569,6 +3641,7 @@ dependencies = [ "anyhow", "clap", "fdlimit", + "opentelemetry", "tokio", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index 03fa2dd..9986220 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,10 @@ url = "2.5.4" urlencoding = "2.1.3" uuid = { version = "1.11.0", features = ["v7", "serde"] } derive_more = { version = "1.0.0", features = ["display", "error"] } +prometheus = "0.13.4" +opentelemetry = "0.27.1" +opentelemetry_sdk = "0.27.1" +opentelemetry-prometheus = "0.27.0" [target.'cfg(not(target_family = "unix"))'.dependencies] crossterm = { version = "0.28.1" } diff --git a/src/lib.rs b/src/lib.rs index fc6b49a..ac0d35f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod config; mod embedded_certificate; +pub mod metrics; mod protocols; mod restrictions; #[cfg(test)] @@ -371,7 +372,7 @@ pub async fn run_client(args: Client) -> anyhow::Result<()> { Ok(()) } -pub async fn run_server(args: Server) -> anyhow::Result<()> { +pub async fn run_server(args: Server, unbounded_metrics: bool) -> anyhow::Result<()> { let tls_config = if args.remote_addr.scheme() == "wss" { let tls_certificate = if let Some(cert_path) = &args.tls_certificate { tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate") @@ -449,7 +450,7 @@ pub async fn run_server(args: Server) -> anyhow::Result<()> { restriction_config: args.restrict_config, http_proxy, }; - let server = WsServer::new(server_config); + let server = WsServer::new(server_config, unbounded_metrics); info!( "Starting wstunnel server v{} with config {:?}", diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs new file mode 100644 index 0000000..ae15b96 --- /dev/null +++ b/src/metrics/mod.rs @@ -0,0 +1,84 @@ +use std::net::SocketAddr; + +use bytes::Bytes; +use http_body_util::Full; +use hyper::service::service_fn; +use hyper::{body, Request, Response, StatusCode, Version}; +use hyper_util::rt::TokioExecutor; + +use opentelemetry_sdk::metrics::SdkMeterProvider; +use prometheus::{Encoder, TextEncoder}; +use tokio::net::TcpListener; +use tracing::{error, info, warn}; + +pub async fn setup_metrics_provider(addr: &SocketAddr) -> anyhow::Result { + let registry = prometheus::Registry::new(); + + // configure OpenTelemetry to use this registry + let exporter = opentelemetry_prometheus::exporter() + .with_registry(registry.clone()) + .build()?; + + // set up a meter to create instruments + let provider = SdkMeterProvider::builder().with_reader(exporter).build(); + let listener = TcpListener::bind(addr).await?; + info!("Started metrics server on {}", addr); + + tokio::spawn(async move { + loop { + let (stream, _) = match listener.accept().await { + Ok(ret) => ret, + Err(err) => { + warn!("Error while accepting connection on metrics port {:?}", err); + continue; + } + }; + + let stream = hyper_util::rt::TokioIo::new(stream); + let conn = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + let fut = conn + .serve_connection( + stream, + service_fn(|req: Request| { + if req.uri().path() != "/metrics" { + return std::future::ready( + Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Default::default()), + ); + } + + if req.version() != Version::HTTP_11 { + return std::future::ready( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Full::::from("Failed to generate metrics")), + ); + } + + // Create handler local registry for ownership + let encoder = TextEncoder::new(); + let metric_families = registry.gather(); + let mut result = Vec::new(); + if let Err(err) = encoder.encode(&metric_families, &mut result) { + error!("Failed to encode prometheus metrics: {:?}", err); + return std::future::ready( + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Full::::from("Failed to generate metrics")), + ); + } + + std::future::ready(Ok(Response::new(Full::::from(result)))) + }), + ) + .await; + + if let Err(err) = fut { + warn!("Failed to handle metrics connection: {:?}", err) + } + } + }); + + return Ok(provider); +} diff --git a/src/tunnel/server/handler_websocket.rs b/src/tunnel/server/handler_websocket.rs index 3eee4c6..a36f451 100644 --- a/src/tunnel/server/handler_websocket.rs +++ b/src/tunnel/server/handler_websocket.rs @@ -3,7 +3,6 @@ use crate::tunnel::server::utils::{bad_request, inject_cookie, HttpResponse}; use crate::tunnel::server::WsServer; use crate::tunnel::transport; use crate::tunnel::transport::websocket::mk_websocket_tunnel; -use bytes::Bytes; use fastwebsockets::Role; use http_body_util::combinators::BoxBody; use http_body_util::Either; diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index 53984aa..83d10a0 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -1,18 +1,18 @@ use anyhow::anyhow; use futures_util::FutureExt; use http_body_util::Either; +use opentelemetry::metrics::Counter; +use opentelemetry::{global, KeyValue}; use std::fmt; use std::fmt::{Debug, Formatter}; use crate::protocols; use crate::tunnel::{try_to_sock_addr, LocalProtocol, RemoteAddr}; use arc_swap::ArcSwap; -use bytes::Bytes; -use http_body_util::combinators::BoxBody; use hyper::body::Incoming; use hyper::server::conn::{http1, http2}; use hyper::service::service_fn; -use hyper::{http, Request, Response, StatusCode, Version}; +use hyper::{http, Request, StatusCode, Version}; use hyper_util::rt::{TokioExecutor, TokioTimer}; use parking_lot::Mutex; use socket2::SockRef; @@ -65,15 +65,29 @@ pub struct WsServerConfig { pub http_proxy: Option, } +pub struct WsServerMetrics { + pub unbounded: bool, + pub connections: Counter, +} + #[derive(Clone)] pub struct WsServer { pub config: Arc, + pub metrics: Arc, } impl WsServer { - pub fn new(config: WsServerConfig) -> Self { + pub fn new(config: WsServerConfig, unbounded_metrics: bool) -> Self { + let meter = global::meter_provider().meter("wstunnel"); Self { config: Arc::new(config), + metrics: Arc::new(WsServerMetrics { + unbounded: unbounded_metrics, + connections: meter + .u64_counter("connections_created") + .with_description("Counts the connections created. Attributes allow to split by remote host") + .build(), + }), } } @@ -127,6 +141,15 @@ impl WsServer { })?; info!("Tunnel accepted due to matched restriction: {}", restriction.name); + let attributes: &[KeyValue] = if self.metrics.unbounded { + &[ + KeyValue::new("remote_host", format!("{}", remote.host)), + KeyValue::new("remote_port", i64::from(remote.port)), + ] + } else { + &[] + }; + self.metrics.connections.add(1, attributes); let req_protocol = remote.protocol.clone(); let inject_cookie = req_protocol.is_dynamic_reverse_tunnel(); let tunnel = self diff --git a/src/tunnel/server/utils.rs b/src/tunnel/server/utils.rs index e4f6515..c572fb4 100644 --- a/src/tunnel/server/utils.rs +++ b/src/tunnel/server/utils.rs @@ -12,7 +12,6 @@ use hyper::body::{Body, Incoming}; use hyper::header::{HeaderValue, COOKIE, SEC_WEBSOCKET_PROTOCOL}; use hyper::{http, Request, Response, StatusCode}; use jsonwebtoken::TokenData; -use std::cmp::min; use std::net::IpAddr; use tracing::{error, info, warn}; use url::Host; diff --git a/wstunnel-cli/Cargo.toml b/wstunnel-cli/Cargo.toml index c8a3ae6..76fbe39 100644 --- a/wstunnel-cli/Cargo.toml +++ b/wstunnel-cli/Cargo.toml @@ -9,6 +9,7 @@ anyhow = "1.0.95" clap = { version = "4.5.23", features = ["derive", "env"] } fdlimit = "0.3.0" +opentelemetry = "0.27.1" tokio = { version = "1.42.0", features = ["full"] } @@ -19,4 +20,4 @@ wstunnel = { path = ".." , features = ["clap"] } [[bin]] name = "wstunnel" -path = "src/main.rs" \ No newline at end of file +path = "src/main.rs" diff --git a/wstunnel-cli/src/main.rs b/wstunnel-cli/src/main.rs index 31a169b..a0481c9 100644 --- a/wstunnel-cli/src/main.rs +++ b/wstunnel-cli/src/main.rs @@ -1,12 +1,15 @@ use clap::Parser; use std::io; +use std::net::SocketAddr; use std::str::FromStr; +use opentelemetry::global; use tracing::warn; use tracing_subscriber::filter::Directive; use tracing_subscriber::EnvFilter; use wstunnel::config::{Client, Server}; use wstunnel::LocalProtocol; use wstunnel::{run_client, run_server}; +use wstunnel::metrics; /// Use Websocket or HTTP2 protocol to tunnel {TCP,UDP} traffic /// wsTunnelClient <---> wsTunnelServer <---> RemoteHost @@ -43,6 +46,24 @@ pub struct Wstunnel { default_value = "INFO" )] log_lvl: String, + + /// Set the listen address for the prometheus metrics exporter. + #[arg( + long, + global = true, + verbatim_doc_comment, + default_value = None, + )] + metrics_provider_address: Option, + + /// Allow metrics to take up unbounded space (OOM risk!). + #[arg( + long, + global = true, + verbatim_doc_comment, + default_value = "false", + )] + metrics_unbounded: bool, } #[derive(clap::Subcommand, Debug)] @@ -84,12 +105,23 @@ async fn main() -> anyhow::Result<()> { warn!("Failed to set soft filelimit to hard file limit: {}", err) } + if let Some(addr) = args.metrics_provider_address { + match metrics::setup_metrics_provider(&addr).await { + Ok(provider) => { + let _ = global::set_meter_provider(provider); + } + Err(err) => { + panic!("Failed to setup metrics server: {err:?}") + } + } + } + match args.commands { Commands::Client(args) => { run_client(*args).await?; } - Commands::Server(args) => { - run_server(*args).await?; + Commands::Server(server_args) => { + run_server(*server_args, args.metrics_unbounded).await?; } }