|
1 | | -use log::{error, info}; |
| 1 | +use std::collections::HashMap; |
| 2 | +use std::net::SocketAddr; |
| 3 | +use std::net::ToSocketAddrs; |
2 | 4 | use std::sync::Arc; |
| 5 | +use std::time::Duration; |
3 | 6 | use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; |
4 | | -use tokio::net::TcpStream; |
| 7 | +use tokio::net::{TcpListener, TcpStream}; |
| 8 | +use tokio::sync::RwLock; |
5 | 9 | use tokio::time::timeout; |
| 10 | +use tracing::{error, info}; |
| 11 | +use url::Url; |
| 12 | +use vectorize_core::types::VectorizeJob; |
6 | 13 |
|
7 | 14 | use super::message_parser::{log_message_processing, try_parse_complete_message}; |
8 | 15 | use super::protocol::{BUFFER_SIZE, ProxyConfig, WireProxyError}; |
@@ -129,3 +136,55 @@ where |
129 | 136 | info!("Standard proxy stream closed: {total_bytes} bytes transferred"); |
130 | 137 | Ok(()) |
131 | 138 | } |
| 139 | + |
| 140 | +pub async fn start_postgres_proxy( |
| 141 | + proxy_port: u16, |
| 142 | + database_url: String, |
| 143 | + job_cache: Arc<RwLock<HashMap<String, VectorizeJob>>>, |
| 144 | + db_pool: sqlx::PgPool, |
| 145 | +) -> Result<(), Box<dyn std::error::Error>> { |
| 146 | + let bind_address = "0.0.0.0"; |
| 147 | + let timeout = 30; |
| 148 | + |
| 149 | + let listen_addr: SocketAddr = format!("{}:{}", bind_address, proxy_port).parse()?; |
| 150 | + |
| 151 | + let url = Url::parse(&database_url)?; |
| 152 | + let postgres_host = url.host_str().unwrap(); |
| 153 | + let postgres_port = url.port().unwrap(); |
| 154 | + |
| 155 | + let postgres_addr: SocketAddr = format!("{postgres_host}:{postgres_port}") |
| 156 | + .to_socket_addrs()? |
| 157 | + .next() |
| 158 | + .ok_or("Failed to resolve PostgreSQL host address")?; |
| 159 | + |
| 160 | + let config = Arc::new(ProxyConfig { |
| 161 | + postgres_addr, |
| 162 | + timeout: Duration::from_secs(timeout), |
| 163 | + jobmap: job_cache, |
| 164 | + db_pool, |
| 165 | + prepared_statements: Arc::new(RwLock::new(HashMap::new())), |
| 166 | + }); |
| 167 | + |
| 168 | + info!("Proxy listening on: {listen_addr}"); |
| 169 | + info!("Forwarding to PostgreSQL at: {postgres_addr}"); |
| 170 | + |
| 171 | + let listener = TcpListener::bind(listen_addr).await?; |
| 172 | + |
| 173 | + loop { |
| 174 | + match listener.accept().await { |
| 175 | + Ok((client_stream, client_addr)) => { |
| 176 | + info!("New proxy connection from: {client_addr}"); |
| 177 | + |
| 178 | + let config = Arc::clone(&config); |
| 179 | + tokio::spawn(async move { |
| 180 | + if let Err(e) = handle_connection_with_timeout(client_stream, config).await { |
| 181 | + error!("Proxy connection error from {client_addr}: {e}"); |
| 182 | + } |
| 183 | + }); |
| 184 | + } |
| 185 | + Err(e) => { |
| 186 | + error!("Failed to accept proxy connection: {e}"); |
| 187 | + } |
| 188 | + } |
| 189 | + } |
| 190 | +} |
0 commit comments