Skip to content

Commit adf057e

Browse files
authored
refactor appState and proxy (#258)
1 parent db05f15 commit adf057e

18 files changed

Lines changed: 444 additions & 239 deletions

File tree

Cargo.lock

Lines changed: 8 additions & 34 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ serde = "1.0.219"
3131
serde_json = "1.0"
3232
sqlparser = "0.51"
3333
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "uuid", "time"] }
34+
tracing = "0.1"
35+
tracing-log = "0.1"
36+
tracing-subscriber = "0.3.20"
3437
thiserror = "2.0.12"
3538
tiktoken-rs = "0.7.0"
3639
tokio = { version = "1.0", features = ["full"] }

extension/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ clean:
4242
setup.dependencies: install-pg_cron install-pgvector install-pgmq install-vectorscale
4343
setup.shared_preload_libraries:
4444
echo "shared_preload_libraries = 'pg_cron, vectorize'" >> ~/.pgrx/data-${PG_VERSION}/postgresql.conf
45+
echo "cron.database_name = 'postgres'" >> ~/.pgrx/data-${PG_VERSION}/postgresql.conf
4546
setup.urls:
4647
echo "vectorize.embedding_service_url = 'http://localhost:3000/v1'" >> ~/.pgrx/data-${PG_VERSION}/postgresql.conf
4748
echo "vectorize.ollama_service_url = 'http://localhost:3001'" >> ~/.pgrx/data-${PG_VERSION}/postgresql.conf
@@ -93,7 +94,7 @@ test-integration:
9394
cargo test ${TEST_NAME} -- --ignored --test-threads=1 --nocapture
9495

9596
test-unit:
96-
cargo test ${TEST_NAME} -- --test-threads=1
97+
cargo test ${TEST_NAME} -- --test-threads=1 --nocapture
9798

9899
test-version:
99100
git fetch --tags

extension/src/api.rs

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -176,40 +176,6 @@ fn encode(
176176
Ok(transform(input, &model, api_key).remove(0))
177177
}
178178

179-
#[allow(clippy::too_many_arguments)]
180-
#[deprecated(since = "0.22.0", note = "Please use vectorize.table() instead")]
181-
#[pg_extern]
182-
fn init_rag(
183-
agent_name: &str,
184-
table_name: &str,
185-
unique_record_id: &str,
186-
// column that have data we want to be able to chat with
187-
column: &str,
188-
schema: default!(&str, "'public'"),
189-
index_dist_type: default!(types::IndexDist, "'pgv_hnsw_cosine'"),
190-
// transformer model to use in vector-search
191-
transformer: default!(&str, "'sentence-transformers/all-MiniLM-L6-v2'"),
192-
table_method: default!(types::TableMethod, "'join'"),
193-
schedule: default!(&str, "'* * * * *'"),
194-
) -> Result<String> {
195-
pgrx::warning!("DEPRECATED: vectorize.init_rag() will be removed in a future version. Please use vectorize.table() instead.");
196-
// chat only supports single columns transform
197-
let columns = vec![column.to_string()];
198-
let transformer_model = Model::new(transformer)?;
199-
init_table(
200-
agent_name,
201-
schema,
202-
table_name,
203-
columns,
204-
unique_record_id,
205-
None,
206-
index_dist_type.into(),
207-
&transformer_model,
208-
table_method.into(),
209-
schedule,
210-
)
211-
}
212-
213179
/// creates a table indexed with embeddings for chat completion workloads
214180
#[pg_extern]
215181
fn rag(

extension/src/executor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub fn batch_texts(
2525
return TableIterator::new(vec![record_ids].into_iter().map(|arr| (arr,)));
2626
}
2727

28-
let num_batches = (total_records + batch_size - 1) / batch_size;
28+
let num_batches = total_records.div_ceil(batch_size);
2929

3030
let mut batches = Vec::with_capacity(num_batches);
3131

extension/src/guc.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,6 @@ pub fn get_guc(guc: VectorizeGuc) -> Option<String> {
250250
}
251251
}
252252

253-
#[allow(dead_code)]
254253
fn handle_cstr(cstr: &CStr) -> Result<String> {
255254
if let Ok(s) = cstr.to_str() {
256255
Ok(s.to_owned())

extension/src/search.rs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ pub fn cosine_similarity_search(
429429
num_results,
430430
where_clause,
431431
),
432-
TableMethod::join => query::join_table_cosine_similarity(
432+
TableMethod::join => join_table_cosine_similarity(
433433
project,
434434
&job_params.schema,
435435
&job_params.relation,
@@ -452,6 +452,52 @@ pub fn cosine_similarity_search(
452452
})
453453
}
454454

455+
pub fn join_table_cosine_similarity(
456+
project: &str,
457+
schema: &str,
458+
table: &str,
459+
join_key: &str,
460+
return_columns: &[String],
461+
num_results: i32,
462+
where_clause: Option<String>,
463+
) -> String {
464+
let cols = &return_columns
465+
.iter()
466+
.map(|s| format!("t0.{s}"))
467+
.collect::<Vec<_>>()
468+
.join(",");
469+
let where_str = if let Some(w) = where_clause {
470+
prepare_filter(&w, join_key)
471+
} else {
472+
"".to_string()
473+
};
474+
let inner_query = format!(
475+
"
476+
SELECT
477+
{join_key},
478+
1 - (embeddings <=> $1::vector) AS similarity_score
479+
FROM vectorize._embeddings_{project}
480+
ORDER BY similarity_score DESC
481+
"
482+
);
483+
format!(
484+
"
485+
SELECT to_jsonb(t) as results
486+
FROM (
487+
SELECT {cols}, t1.similarity_score
488+
FROM
489+
(
490+
{inner_query}
491+
) t1
492+
INNER JOIN {schema}.{table} t0 on t0.{join_key} = t1.{join_key}
493+
{where_str}
494+
) t
495+
ORDER BY t.similarity_score DESC
496+
LIMIT {num_results};
497+
"
498+
)
499+
}
500+
455501
fn single_table_cosine_similarity(
456502
project: &str,
457503
schema: &str,
@@ -482,3 +528,9 @@ fn single_table_cosine_similarity(
482528
cols = return_columns.join(", "),
483529
)
484530
}
531+
532+
// transform user's where_sql into the format search query expects
533+
// Transform the user's where_sql into the format the search query expects:
// qualify every standalone occurrence of the primary-key column with the
// source-table alias `t0`, and prefix the clause with `AND` so it can be
// appended after the join condition.
//
// A plain `str::replace` would also rewrite `pkey` when it appears inside a
// longer identifier (e.g. pkey `id` would corrupt `product_id` into
// `product_t0.id`), so replacement only happens when the match is not
// adjacent to an identifier character (ASCII alphanumeric or `_`).
fn prepare_filter(filter: &str, pkey: &str) -> String {
    // An empty key can never be a real column name; leave the filter as-is
    // (also avoids an infinite empty-match loop below).
    if pkey.is_empty() {
        return format!("AND {filter}");
    }

    let bytes = filter.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let mut out = String::with_capacity(filter.len() + pkey.len() + 8);
    let mut i = 0;
    while i < filter.len() {
        if filter[i..].starts_with(pkey) {
            let end = i + pkey.len();
            // Only replace when the match is a whole identifier, not a
            // substring of a longer one.
            let boundary_before = i == 0 || !is_ident(bytes[i - 1]);
            let boundary_after = end >= filter.len() || !is_ident(bytes[end]);
            if boundary_before && boundary_after {
                out.push_str("t0.");
                out.push_str(pkey);
                i = end;
                continue;
            }
        }
        // Advance by one full character to stay on UTF-8 boundaries.
        let ch = filter[i..].chars().next().expect("i is on a char boundary");
        out.push(ch);
        i += ch.len_utf8();
    }
    format!("AND {out}")
}

extension/tests/util.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ pub mod common {
77
use sqlx::{Pool, Postgres, Row};
88
use url::{ParseError, Url};
99

10-
#[allow(dead_code)]
1110
#[derive(FromRow, Debug, serde::Deserialize)]
1211
pub struct SearchResult {
1312
pub product_id: i32,
@@ -16,7 +15,6 @@ pub mod common {
1615
pub similarity_score: f64,
1716
}
1817

19-
#[allow(dead_code)]
2018
#[derive(FromRow, Debug, Serialize)]
2119
pub struct SearchJSON {
2220
pub search_results: serde_json::Value,

proxy/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,8 @@ serde_json = { workspace = true }
1313
sqlx = { workspace = true}
1414
thiserror = { workspace = true }
1515
tokio = { workspace = true }
16+
tracing = { workspace = true }
17+
tracing-subscriber = { workspace = true }
18+
url = { workspace = true }
1619

1720
pgwire = { version = "0.30", features = ["server-api-aws-lc-rs"] }

proxy/src/proxy.rs

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
use log::{error, info};
1+
use std::collections::HashMap;
2+
use std::net::SocketAddr;
3+
use std::net::ToSocketAddrs;
24
use std::sync::Arc;
5+
use std::time::Duration;
36
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4-
use tokio::net::TcpStream;
7+
use tokio::net::{TcpListener, TcpStream};
8+
use tokio::sync::RwLock;
59
use tokio::time::timeout;
10+
use tracing::{error, info};
11+
use url::Url;
12+
use vectorize_core::types::VectorizeJob;
613

714
use super::message_parser::{log_message_processing, try_parse_complete_message};
815
use super::protocol::{BUFFER_SIZE, ProxyConfig, WireProxyError};
@@ -129,3 +136,55 @@ where
129136
info!("Standard proxy stream closed: {total_bytes} bytes transferred");
130137
Ok(())
131138
}
139+
140+
pub async fn start_postgres_proxy(
141+
proxy_port: u16,
142+
database_url: String,
143+
job_cache: Arc<RwLock<HashMap<String, VectorizeJob>>>,
144+
db_pool: sqlx::PgPool,
145+
) -> Result<(), Box<dyn std::error::Error>> {
146+
let bind_address = "0.0.0.0";
147+
let timeout = 30;
148+
149+
let listen_addr: SocketAddr = format!("{}:{}", bind_address, proxy_port).parse()?;
150+
151+
let url = Url::parse(&database_url)?;
152+
let postgres_host = url.host_str().unwrap();
153+
let postgres_port = url.port().unwrap();
154+
155+
let postgres_addr: SocketAddr = format!("{postgres_host}:{postgres_port}")
156+
.to_socket_addrs()?
157+
.next()
158+
.ok_or("Failed to resolve PostgreSQL host address")?;
159+
160+
let config = Arc::new(ProxyConfig {
161+
postgres_addr,
162+
timeout: Duration::from_secs(timeout),
163+
jobmap: job_cache,
164+
db_pool,
165+
prepared_statements: Arc::new(RwLock::new(HashMap::new())),
166+
});
167+
168+
info!("Proxy listening on: {listen_addr}");
169+
info!("Forwarding to PostgreSQL at: {postgres_addr}");
170+
171+
let listener = TcpListener::bind(listen_addr).await?;
172+
173+
loop {
174+
match listener.accept().await {
175+
Ok((client_stream, client_addr)) => {
176+
info!("New proxy connection from: {client_addr}");
177+
178+
let config = Arc::clone(&config);
179+
tokio::spawn(async move {
180+
if let Err(e) = handle_connection_with_timeout(client_stream, config).await {
181+
error!("Proxy connection error from {client_addr}: {e}");
182+
}
183+
});
184+
}
185+
Err(e) => {
186+
error!("Failed to accept proxy connection: {e}");
187+
}
188+
}
189+
}
190+
}

0 commit comments

Comments
 (0)