Skip to content

Commit 262f36b

Browse files
authored
Merge pull request #18 from tembo-io/return_cols
Flexible return columns
2 parents 019fa00 + 99f426d commit 262f36b

6 files changed

Lines changed: 59 additions & 33 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "vectorize"
3-
version = "0.2.0"
3+
version = "0.3.0"
44
edition = "2021"
55
publish = false
66

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,17 @@ Finally, search.
5757
```sql
5858
SELECT * FROM vectorize.search(
5959
job_name => 'product_search',
60-
return_col => 'product_name',
6160
query => 'accessories for mobile devices',
62-
api_key => 'my-openai-key"',
61+
api_key => 'my-openai-key',
62+
return_columns => ARRAY['product_id', 'product_name'],
6363
num_results => 3
6464
);
6565
```
6666

6767
```text
68-
search_results
69-
--------------------------------------------------------------------------------------------------
70-
{"value": "Phone Charger", "column": "product_name", "similarity_score": 0.8530797672121025}
71-
{"value": "Tablet Holder", "column": "product_name", "similarity_score": 0.8284493388477342}
72-
{"value": "Bluetooth Speaker", "column": "product_name", "similarity_score": 0.8255034841826178}
68+
search_results
69+
------------------------------------------------------------------------------------------------
70+
{"product_id": 13, "product_name": "Phone Charger", "similarity_score": 0.8564774308489237}
71+
{"product_id": 24, "product_name": "Tablet Holder", "similarity_score": 0.8295404213393001}
72+
{"product_id": 4, "product_name": "Bluetooth Speaker", "similarity_score": 0.8248579643539758}
7373
```

Trunk.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description = "The simplest implementation of LLM-backed vector search on Postgr
66
homepage = "https://github.com/tembo-io/pg_vectorize"
77
documentation = "https://github.com/tembo-io/pg_vectorize"
88
categories = ["orchestration", "machine_learning"]
9-
version = "0.2.0"
9+
version = "0.3.0"
1010

1111
[build]
1212
postgres_version = "15"

sql/vectorize--0.2.0--0.3.0.sql

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
DROP FUNCTION vectorize."table";
2+
3+
CREATE FUNCTION vectorize."table"(
4+
"table" TEXT, /* &str */
5+
"columns" TEXT[], /* alloc::vec::Vec<alloc::string::String> */
6+
"job_name" TEXT, /* core::option::Option<alloc::string::String> */
7+
"args" json, /* pgrx::datum::json::Json */
8+
"primary_key" TEXT, /* alloc::string::String */
9+
"schema" TEXT DEFAULT 'public', /* alloc::string::String */
10+
"update_col" TEXT DEFAULT 'last_updated_at', /* alloc::string::String */
11+
"transformer" vectorize.Transformer DEFAULT 'openai', /* vectorize::types::Transformer */
12+
"search_alg" vectorize.SimilarityAlg DEFAULT 'pgv_cosine_similarity', /* vectorize::types::SimilarityAlg */
13+
"table_method" vectorize.TableMethod DEFAULT 'append', /* vectorize::init::TableMethod */
14+
"schedule" TEXT DEFAULT '* * * * *' /* alloc::string::String */
15+
) RETURNS TEXT /* core::result::Result<alloc::string::String, anyhow::Error> */
16+
LANGUAGE c /* Rust */
17+
18+
AS 'MODULE_PATHNAME', 'table_wrapper';
19+
20+
DROP FUNCTION vectorize."search";
21+
22+
CREATE FUNCTION vectorize."search"(
23+
"job_name" TEXT, /* &str */
24+
"query" TEXT, /* &str */
25+
"api_key" TEXT, /* &str */
26+
"return_columns" TEXT[] DEFAULT ARRAY['*']::text[], /* alloc::vec::Vec<alloc::string::String> */
27+
"num_results" INT DEFAULT 10 /* i32 */
28+
) RETURNS TABLE (
29+
"search_results" jsonb /* pgrx::datum::json::JsonB */
30+
)
31+
STRICT
32+
LANGUAGE c /* Rust */
33+
AS 'MODULE_PATHNAME', 'search_wrapper';

src/api.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ fn table(
116116
#[pg_extern]
117117
fn search(
118118
job_name: &str,
119-
return_col: &str,
120119
query: &str,
121120
api_key: &str,
121+
return_columns: default!(Vec<String>, "ARRAY['*']::text[]"),
122122
num_results: default!(i32, 10),
123123
) -> Result<TableIterator<'static, (name!(search_results, pgrx::JsonB),)>, spi::Error> {
124124
// note: this is not the most performant implementation
@@ -159,7 +159,7 @@ fn search(
159159
job_name,
160160
&schema,
161161
&table,
162-
return_col,
162+
&return_columns,
163163
num_results,
164164
&embeddings[0],
165165
)?;

src/search.rs

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,36 @@ pub fn cosine_similarity_search(
44
project: &str,
55
schema: &str,
66
table: &str,
7-
return_col: &str,
7+
return_columns: &[String],
88
num_results: i32,
99
embeddings: &[f64],
1010
) -> Result<Vec<(pgrx::JsonB,)>, spi::Error> {
1111
let emb = serde_json::to_string(&embeddings).expect("failed to serialize embeddings");
1212
let query = format!(
1313
"
14-
SELECT
15-
1 - ({project}_embeddings <=> '{emb}'::vector) AS cosine_similarity,
16-
*
14+
SELECT to_jsonb(t)
15+
as results FROM (
16+
SELECT
17+
1 - ({project}_embeddings <=> '{emb}'::vector) AS similarity_score,
18+
{cols}
1719
FROM {schema}.{table}
1820
WHERE {project}_updated_at is NOT NULL
19-
ORDER BY cosine_similarity DESC
20-
LIMIT {num_results};
21-
"
21+
ORDER BY similarity_score DESC
22+
LIMIT {num_results}
23+
) t
24+
",
25+
cols = return_columns.join(", "),
2226
);
2327
log!("query: {}", query);
2428
Spi::connect(|client| {
2529
let mut results: Vec<(pgrx::JsonB,)> = Vec::new();
2630
let tup_table = client.select(&query, None, None)?;
27-
2831
for row in tup_table {
29-
let v = row[return_col]
30-
.value::<String>()
31-
.expect("failed to get value");
32-
let score = row["cosine_similarity"]
33-
.value::<f64>()
34-
.expect("failed to get value");
35-
36-
let r = serde_json::json!({
37-
"column": return_col,
38-
"value": v,
39-
"similarity_score": score
40-
});
41-
results.push((pgrx::JsonB(r),));
32+
match row["results"].value()? {
33+
Some(r) => results.push((r,)),
34+
None => error!("failed to get results"),
35+
}
4236
}
43-
4437
Ok(results)
4538
})
4639
}

0 commit comments

Comments
 (0)