Skip to content

Commit 51d95ea

Browse files
authored
Pass through model name on embedding service calls (#35)
* add transform to sql api * pass through transformer name * clippy * any SentenceTransformer * add migration script * bump toml: * update default * update comments
1 parent d4ed24e commit 51d95ea

14 files changed

Lines changed: 123 additions & 144 deletions

File tree

Cargo.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "vectorize"
3-
version = "0.7.0"
3+
version = "0.8.0"
44
edition = "2021"
55
publish = false
66

@@ -9,16 +9,18 @@ crate-type = ["cdylib"]
99

1010
[features]
1111
default = ["pg15"]
12+
pg14 = ["pgrx/pg14", "pgrx-tests/pg14"]
1213
pg15 = ["pgrx/pg15", "pgrx-tests/pg15"]
14+
pg16 = ["pgrx/pg16", "pgrx-tests/pg16"]
1315
pg_test = []
1416

1517
[dependencies]
1618
anyhow = "1.0.72"
1719
chrono = {version = "0.4.26", features = ["serde"] }
1820
lazy_static = "1.4.0"
1921
log = "0.4.19"
20-
pgmq = "0.24.0"
21-
pgrx = "0.11.0"
22+
pgmq = "0.26.0"
23+
pgrx = "0.11.2"
2224
postgres-types = "0.2.5"
2325
regex = "1.9.2"
2426
reqwest = {version = "0.11.18", features = ["json"] }
@@ -35,7 +37,7 @@ tokio = {version = "1.29.1", features = ["rt-multi-thread"] }
3537
url = "2.4.0"
3638

3739
[dev-dependencies]
38-
pgrx-tests = "0.11.0"
40+
pgrx-tests = "0.11.2"
3941
rand = "0.8.5"
4042
whoami = "1.4.1"
4143

Trunk.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres."
66
homepage = "https://github.com/tembo-io/pg_vectorize"
77
documentation = "https://github.com/tembo-io/pg_vectorize"
88
categories = ["orchestration", "machine_learning"]
9-
version = "0.7.0"
9+
version = "0.8.0"
1010

1111
[build]
1212
postgres_version = "15"
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
DROP function vectorize."table";
22

3-
-- src/api.rs:14
3+
-- src/api.rs:15
44
-- vectorize::api::table
55
CREATE FUNCTION vectorize."table"(
66
"table" TEXT, /* &str */
@@ -10,7 +10,7 @@ CREATE FUNCTION vectorize."table"(
1010
"args" json DEFAULT '{}', /* pgrx::datum::json::Json */
1111
"schema" TEXT DEFAULT 'public', /* alloc::string::String */
1212
"update_col" TEXT DEFAULT 'last_updated_at', /* alloc::string::String */
13-
"transformer" vectorize.Transformer DEFAULT 'text_embedding_ada_002', /* vectorize::types::Transformer */
13+
"transformer" TEXT DEFAULT 'text-embedding-ada-002', /* alloc::string::String */
1414
"search_alg" vectorize.SimilarityAlg DEFAULT 'pgv_cosine_similarity', /* vectorize::types::SimilarityAlg */
1515
"table_method" vectorize.TableMethod DEFAULT 'append', /* vectorize::types::TableMethod */
1616
"schedule" TEXT DEFAULT '* * * * *' /* alloc::string::String */
@@ -19,12 +19,13 @@ STRICT
1919
LANGUAGE c /* Rust */
2020
AS 'MODULE_PATHNAME', 'table_wrapper';
2121

22-
-- src/api.rs:172
22+
DROP FUNCTION vectorize."transform_embeddings";
23+
-- src/api.rs:170
2324
-- vectorize::api::transform_embeddings
2425
CREATE FUNCTION vectorize."transform_embeddings"(
2526
"input" TEXT, /* &str */
26-
"model_name" vectorize.Transformer DEFAULT 'text_embedding_ada_002', /* vectorize::types::Transformer */
27+
"model_name" TEXT DEFAULT 'text-embedding-ada-002', /* alloc::string::String */
2728
"api_key" TEXT DEFAULT NULL /* core::option::Option<alloc::string::String> */
2829
) RETURNS double precision[] /* core::result::Result<alloc::vec::Vec<f64>, pgrx::spi::SpiError> */
2930
LANGUAGE c /* Rust */
30-
AS 'MODULE_PATHNAME', 'transform_embeddings_wrapper';
31+
AS 'MODULE_PATHNAME', 'transform_embeddings_wrapper';

src/api.rs

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::executor::VectorizeMeta;
22
use crate::guc;
33
use crate::init;
44
use crate::search::cosine_similarity_search;
5+
use crate::transformers::http_handler::sync_get_model_info;
56
use crate::transformers::{openai, transform};
67
use crate::types;
78
use crate::types::JobParams;
@@ -19,7 +20,7 @@ fn table(
1920
args: default!(pgrx::Json, "'{}'"),
2021
schema: default!(String, "'public'"),
2122
update_col: default!(String, "'last_updated_at'"),
22-
transformer: default!(types::Transformer, "'text_embedding_ada_002'"),
23+
transformer: default!(String, "'text-embedding-ada-002'"),
2324
search_alg: default!(types::SimilarityAlg, "'pgv_cosine_similarity'"),
2425
table_method: default!(types::TableMethod, "'append'"),
2526
schedule: default!(String, "'* * * * *'"),
@@ -38,12 +39,12 @@ fn table(
3839

3940
// get prim key type
4041
let pkey_type = init::get_column_datatype(&schema, table, &primary_key);
42+
init::init_pgmq()?;
4143

4244
// certain embedding services require an API key, e.g. openAI
4345
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
44-
init::init_pgmq(&transformer)?;
45-
match transformer {
46-
types::Transformer::text_embedding_ada_002 => {
46+
match transformer.as_ref() {
47+
"text-embedding-ada-002" => {
4748
let openai_key = match api_key {
4849
Some(k) => serde_json::from_value::<String>(k.clone())?,
4950
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
@@ -55,8 +56,10 @@ fn table(
5556
};
5657
openai::validate_api_key(&openai_key)?;
5758
}
58-
// no-op
59-
types::Transformer::all_MiniLM_L12_v2 => (),
59+
t => {
60+
// make sure transformer exists
61+
let _ = sync_get_model_info(t).expect("transformer does not exist");
62+
}
6063
}
6164

6265
let valid_params = types::JobParams {
@@ -105,14 +108,8 @@ fn table(
105108
if ran.is_err() {
106109
error!("error creating job");
107110
}
108-
let init_embed_q = init::init_embedding_table_query(
109-
&job_name,
110-
&schema,
111-
table,
112-
&transformer,
113-
&search_alg,
114-
&table_method,
115-
);
111+
let init_embed_q =
112+
init::init_embedding_table_query(&job_name, &schema, table, &transformer, &table_method);
116113

117114
let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
118115
for q in init_embed_q {
@@ -152,7 +149,7 @@ fn search(
152149
let schema = proj_params.schema;
153150
let table = proj_params.table;
154151

155-
let embeddings = transform(query, project_meta.transformer, api_key);
152+
let embeddings = transform(query, &project_meta.transformer, api_key);
156153

157154
let search_results = match project_meta.search_alg {
158155
types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search(
@@ -171,8 +168,8 @@ fn search(
171168
#[pg_extern]
172169
fn transform_embeddings(
173170
input: &str,
174-
model_name: default!(types::Transformer, "'text_embedding_ada_002'"),
171+
model_name: default!(String, "'text-embedding-ada-002'"),
175172
api_key: default!(Option<String>, "NULL"),
176173
) -> Result<Vec<f64>, spi::Error> {
177-
Ok(transform(input, model_name, api_key).remove(0))
174+
Ok(transform(input, &model_name, api_key).remove(0))
178175
}

src/executor.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use pgrx::prelude::*;
22

33
use crate::errors::DatabaseError;
44
use crate::guc::BATCH_SIZE;
5-
use crate::init::QUEUE_MAPPING;
5+
use crate::init::VECTORIZE_QUEUE;
66
use crate::query::check_input;
77
use crate::transformers::types::Inputs;
88
use crate::types;
@@ -23,7 +23,7 @@ pub struct VectorizeMeta {
2323
pub job_id: i64,
2424
pub name: String,
2525
pub job_type: types::JobType,
26-
pub transformer: types::Transformer,
26+
pub transformer: String,
2727
pub search_alg: types::SimilarityAlg,
2828
pub params: serde_json::Value,
2929
#[serde(deserialize_with = "from_tsopt")]
@@ -112,11 +112,8 @@ fn job_execute(job_name: String) {
112112
job_meta: meta.clone(),
113113
inputs: b,
114114
};
115-
let queue_name = QUEUE_MAPPING
116-
.get(&meta.transformer)
117-
.expect("invalid transformer");
118115
let msg_id = queue
119-
.send(queue_name, &msg)
116+
.send(VECTORIZE_QUEUE, &msg)
120117
.await
121118
.unwrap_or_else(|e| error!("failed to send message updates: {}", e));
122119
log!("message sent: {}", msg_id);

src/init.rs

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,32 @@
1-
use crate::{query::check_input, types, types::TableMethod, types::Transformer};
1+
use crate::{
2+
query::check_input,
3+
transformers::{http_handler::sync_get_model_info, types::TransformerMetadata},
4+
types,
5+
types::TableMethod,
6+
};
27
use pgrx::prelude::*;
3-
use std::collections::HashMap;
48

59
use anyhow::{Context, Result};
6-
use lazy_static::lazy_static;
710

8-
lazy_static! {
9-
// each model has its own job queue
10-
// maintain the mapping of transformer to queue name here
11-
pub static ref QUEUE_MAPPING: HashMap<Transformer, &'static str> = {
12-
let mut m = HashMap::new();
13-
m.insert(Transformer::text_embedding_ada_002, "v_openai");
14-
m.insert(Transformer::all_MiniLM_L12_v2, "v_all_MiniLM_L12_v2");
15-
m
16-
};
17-
}
11+
pub static VECTORIZE_QUEUE: &str = "vectorize_jobs";
1812

19-
pub fn init_pgmq(transformer: &Transformer) -> Result<()> {
20-
let qname = QUEUE_MAPPING.get(transformer).expect("invalid transformer");
13+
pub fn init_pgmq() -> Result<()> {
2114
// check if queue already created:
2215
let queue_exists: bool = Spi::get_one(&format!(
23-
"SELECT EXISTS (SELECT 1 FROM pgmq.meta WHERE queue_name = '{qname}');",
16+
"SELECT EXISTS (SELECT 1 FROM pgmq.meta WHERE queue_name = '{VECTORIZE_QUEUE}');",
2417
))?
2518
.context("error checking if queue exists")?;
2619
if queue_exists {
20+
info!("queue already exists");
2721
return Ok(());
2822
} else {
23+
info!("creating queue;");
2924
let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
30-
let _r = c.update(&format!("SELECT pgmq.create('{qname}');"), None, None)?;
25+
let _r = c.update(
26+
&format!("SELECT pgmq.create('{VECTORIZE_QUEUE}');"),
27+
None,
28+
None,
29+
)?;
3130
Ok(())
3231
});
3332
if let Err(e) = ran {
@@ -69,38 +68,30 @@ pub fn init_embedding_table_query(
6968
job_name: &str,
7069
schema: &str,
7170
table: &str,
72-
transformer: &types::Transformer,
73-
search_alg: &types::SimilarityAlg,
71+
transformer: &str,
7472
transform_method: &TableMethod,
7573
) -> Vec<String> {
76-
// TODO: when adding support for other models, add the output dimension to the transformer attributes
77-
// so that they can be read here, not hard-coded here below
78-
// currently only supports the text-embedding-ada-002 embedding model - output dim 1536
79-
// https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
80-
8174
check_input(job_name).expect("invalid job name");
82-
let col_type = match (transformer, search_alg) {
83-
// TODO: when adding support for other models, add the output dimension to the transformer attributes
84-
// so that they can be read here, not hard-coded here below
85-
// currently only supports the text-embedding-ada-002 embedding model - output dim 1536
75+
let col_type = match transformer {
8676
// https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
87-
(
88-
types::Transformer::text_embedding_ada_002,
89-
types::SimilarityAlg::pgv_cosine_similarity,
90-
) => "vector(1536)",
91-
(types::Transformer::all_MiniLM_L12_v2, types::SimilarityAlg::pgv_cosine_similarity) => {
92-
"vector(384)"
77+
// for anything but OpenAI, first call info endpoint to get the embedding dim of the model
78+
"text-embedding-ada-002" => "vector(1536)".to_owned(),
79+
_ => {
80+
let model_info: TransformerMetadata = sync_get_model_info(transformer)
81+
.expect("failed to call vectorize.embedding_service_url");
82+
let dim = model_info.embedding_dimension;
83+
format!("vector({dim})")
9384
}
9485
};
9586
match transform_method {
9687
TableMethod::append => {
9788
vec![
98-
append_embedding_column(job_name, schema, table, col_type),
89+
append_embedding_column(job_name, schema, table, &col_type),
9990
create_hnsw_cosine_index(job_name, schema, table),
10091
]
10192
}
10293
TableMethod::join => {
103-
vec![create_embedding_table(job_name, col_type)]
94+
vec![create_embedding_table(job_name, &col_type)]
10495
}
10596
}
10697
}
@@ -125,11 +116,6 @@ fn create_hnsw_cosine_index(job_name: &str, schema: &str, table: &str) -> String
125116
}
126117

127118
fn append_embedding_column(job_name: &str, schema: &str, table: &str, col_type: &str) -> String {
128-
// TODO: when adding support for other models, add the output dimension to the transformer attributes
129-
// so that they can be read here, not hard-coded here below
130-
// currently only supports the text-embedding-ada-002 embedding model - output dim 1536
131-
// https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
132-
133119
check_input(job_name).expect("invalid job name");
134120
format!(
135121
"

src/search.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@ pub fn cosine_similarity_search(
88
num_results: i32,
99
embeddings: &[f64],
1010
) -> Result<Vec<(pgrx::JsonB,)>, spi::Error> {
11-
let emb = serde_json::to_string(&embeddings).expect("failed to serialize embeddings");
1211
let query = format!(
1312
"
1413
SELECT to_jsonb(t)
1514
as results FROM (
1615
SELECT
17-
1 - ({project}_embeddings <=> '{emb}'::vector) AS similarity_score,
16+
1 - ({project}_embeddings <=> $1::vector) AS similarity_score,
1817
{cols}
1918
FROM {schema}.{table}
2019
WHERE {project}_updated_at is NOT NULL
@@ -24,10 +23,16 @@ pub fn cosine_similarity_search(
2423
",
2524
cols = return_columns.join(", "),
2625
);
27-
log!("query: {}", query);
2826
Spi::connect(|client| {
2927
let mut results: Vec<(pgrx::JsonB,)> = Vec::new();
30-
let tup_table = client.select(&query, None, None)?;
28+
let tup_table = client.select(
29+
&query,
30+
None,
31+
Some(vec![(
32+
PgBuiltInOids::FLOAT8ARRAYOID.oid(),
33+
embeddings.into_datum(),
34+
)]),
35+
)?;
3136
for row in tup_table {
3237
match row["results"].value()? {
3338
Some(r) => results.push((r,)),

src/transformers/http_handler.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
use anyhow::Result;
22

3+
use crate::guc;
34
use crate::transformers::types::{
45
EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings,
56
};
67
use pgrx::prelude::*;
78

9+
use super::types::TransformerMetadata;
10+
811
pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
912
resp: reqwest::Response,
1013
method: &'static str,
@@ -59,3 +62,35 @@ pub fn merge_input_output(inputs: Vec<Inputs>, values: Vec<Vec<f64>>) -> Vec<Pai
5962
})
6063
.collect()
6164
}
65+
66+
#[pg_extern]
67+
pub fn mod_info(model_name: &str) -> pgrx::JsonB {
68+
let meta = sync_get_model_info(model_name).unwrap();
69+
pgrx::JsonB(serde_json::to_value(meta).unwrap())
70+
}
71+
72+
pub fn sync_get_model_info(model_name: &str) -> Result<TransformerMetadata> {
73+
let runtime = tokio::runtime::Builder::new_current_thread()
74+
.enable_io()
75+
.enable_time()
76+
.build()
77+
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
78+
let meta = match runtime.block_on(async { get_model_info(model_name).await }) {
79+
Ok(e) => e,
80+
Err(e) => {
81+
error!("error getting embeddings: {}", e);
82+
}
83+
};
84+
Ok(meta)
85+
}
86+
87+
pub async fn get_model_info(model_name: &str) -> Result<TransformerMetadata> {
88+
let svc_url = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
89+
.expect("vectorize.embedding_service_url must be set to a valid service");
90+
let info_url = svc_url.replace("/embeddings", "/info");
91+
let client = reqwest::Client::new();
92+
let req = client.get(info_url).query(&[("model_name", model_name)]);
93+
let resp = req.send().await?;
94+
let meta_response = handle_response::<TransformerMetadata>(resp, "info").await?;
95+
Ok(meta_response)
96+
}

0 commit comments

Comments
 (0)