Skip to content

Commit 3aa3810

Browse files
authored
pull private model in embedding svc (#53)
* handle api key for HF
* parse on GET
* pass header to embedding svc
* refactor
1 parent a2e6da1 commit 3aa3810

11 files changed

Lines changed: 149 additions & 123 deletions

File tree

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,12 +26,13 @@ pgrx = "0.11.3"
2626
postgres-types = "0.2.5"
2727
regex = "1.9.2"
2828
reqwest = {version = "0.11.18", features = ["json"] }
29-
serde = "1.0.173"
29+
serde = { version = "1.0.173", features = ["derive"] }
3030
serde_json = "1.0.103"
31-
sqlx = { version = "0.7.2", features = [
31+
sqlx = { version = "0.7.3", features = [
3232
"runtime-tokio-native-tls",
3333
"postgres",
3434
"chrono",
35+
"json"
3536
] }
3637
thiserror = "1.0.44"
3738
tiktoken-rs = "0.5.7"

src/api.rs

Lines changed: 10 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,7 @@
11
use crate::chat::call_chat;
2-
use crate::executor::VectorizeMeta;
3-
use crate::search::{cosine_similarity_search, init_table};
2+
use crate::search::{self, init_table};
43
use crate::transformers::transform;
54
use crate::types;
6-
use crate::util;
75

86
use anyhow::Result;
97
use pgrx::prelude::*;
@@ -41,41 +39,14 @@ fn table(
4139

4240
#[pg_extern]
4341
fn search(
44-
job_name: &str,
45-
query: &str,
42+
job_name: String,
43+
query: String,
4644
api_key: default!(Option<String>, "NULL"),
4745
return_columns: default!(Vec<String>, "ARRAY['*']::text[]"),
4846
num_results: default!(i32, 10),
4947
) -> Result<TableIterator<'static, (name!(search_results, pgrx::JsonB),)>> {
50-
let project_meta: VectorizeMeta = if let Ok(Some(js)) = util::get_vectorize_meta_spi(job_name) {
51-
js
52-
} else {
53-
error!("failed to get project metadata");
54-
};
55-
let proj_params: types::JobParams = serde_json::from_value(
56-
serde_json::to_value(project_meta.params).unwrap_or_else(|e| {
57-
error!("failed to serialize metadata: {}", e);
58-
}),
59-
)
60-
.unwrap_or_else(|e| error!("failed to deserialize metadata: {}", e));
61-
62-
let schema = proj_params.schema;
63-
let table = proj_params.table;
64-
65-
let embeddings = transform(query, &project_meta.transformer, api_key);
66-
67-
let search_results = match project_meta.search_alg {
68-
types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search(
69-
job_name,
70-
&schema,
71-
&table,
72-
&return_columns,
73-
num_results,
74-
&embeddings[0],
75-
)?,
76-
};
77-
78-
Ok(TableIterator::new(search_results))
48+
let search_results = search::search(&job_name, &query, api_key, return_columns, num_results)?;
49+
Ok(TableIterator::new(search_results.into_iter().map(|r| (r,))))
7950
}
8051

8152
#[pg_extern]
@@ -127,10 +98,10 @@ fn rag(
12798
// chat models: currently only supports gpt 3.5 and 4
12899
// https://platform.openai.com/docs/models/gpt-3-5-turbo
129100
// https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
130-
chat_model: default!(&str, "'gpt-3.5-turbo'"),
101+
chat_model: default!(String, "'gpt-3.5-turbo'"),
131102
// points to the type of prompt template to use
132-
task: default!(&str, "'question_answer'"),
133-
api_key: default!(Option<&str>, "NULL"),
103+
task: default!(String, "'question_answer'"),
104+
api_key: default!(Option<String>, "NULL"),
134105
// number of records to include in the context
135106
num_context: default!(i32, 2),
136107
// truncates context to fit the model's context window
@@ -139,8 +110,8 @@ fn rag(
139110
let resp = call_chat(
140111
agent_name,
141112
query,
142-
chat_model,
143-
task,
113+
&chat_model,
114+
&task,
144115
api_key,
145116
num_context,
146117
force_trim,

src/chat.rs

Lines changed: 24 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
use crate::executor::VectorizeMeta;
22
use crate::guc;
3+
use crate::search;
34
use crate::types;
45
use crate::util::get_vectorize_meta_spi;
56

@@ -39,7 +40,7 @@ pub fn call_chat(
3940
query: &str,
4041
chat_model: &str,
4142
task: &str,
42-
api_key: Option<&str>,
43+
api_key: Option<String>,
4344
num_context: i32,
4445
force_trim: bool,
4546
) -> Result<ChatResponse> {
@@ -60,53 +61,28 @@ pub fn call_chat(
6061
let content_column = job_params.columns[0].clone();
6162
let pk = job_params.primary_key;
6263
let columns = vec![pk.clone(), content_column.clone()];
63-
// query the relevant vectorize table using the query
64-
// TODO: refactor so we can call an internal access vector search function
65-
let search_results: Result<Vec<ContextualSearch>, spi::Error> = Spi::connect(|c| {
66-
let mut results: Vec<ContextualSearch> = Vec::new();
67-
let q = format!(
68-
"
69-
select search_results from vectorize.search(
70-
job_name => '{agent_name}',
71-
query => '{query}',
72-
return_columns => $1,
73-
num_results => {num_context}
74-
)",
75-
);
76-
let tup_table = c.select(
77-
&q,
78-
None,
79-
Some(vec![(
80-
PgBuiltInOids::TEXTARRAYOID.oid(),
81-
columns.into_datum(),
82-
)]),
83-
)?;
84-
85-
for row in tup_table {
86-
let row_pgrx_js: pgrx::JsonB = row.get_by_name("search_results").unwrap().unwrap();
87-
let row_js: serde_json::Value = row_pgrx_js.0;
88-
89-
let record_id = row_js
90-
.get(&pk)
91-
.unwrap_or_else(|| error!("`{pk}` not found"));
92-
let content = row_js
93-
.get(&content_column)
94-
.unwrap_or_else(|| error!("`{content_column}` not found"));
95-
let text_content =
96-
serde_json::to_string(content).expect("failed to serialize content to string");
97-
98-
let token_ct = bpe.encode_ordinary(&text_content).len() as i32;
99-
results.push(ContextualSearch {
100-
record_id: serde_json::to_string(record_id)
101-
.expect("failed to serialize record_id to string"),
102-
content: text_content,
103-
token_ct,
104-
});
105-
}
106-
Ok(results)
107-
});
10864

109-
let search_results = search_results?;
65+
let raw_search = search::search(agent_name, query, api_key.clone(), columns, num_context)?;
66+
67+
let mut search_results: Vec<ContextualSearch> = Vec::new();
68+
for s in raw_search {
69+
let row_js: serde_json::Value = s.0;
70+
let record_id = row_js
71+
.get(&pk)
72+
.unwrap_or_else(|| error!("`{pk}` not found"));
73+
let content = row_js
74+
.get(&content_column)
75+
.unwrap_or_else(|| error!("`{content_column}` not found"));
76+
let text_content =
77+
serde_json::to_string(content).expect("failed to serialize content to string");
78+
let token_ct = bpe.encode_ordinary(&text_content).len() as i32;
79+
search_results.push(ContextualSearch {
80+
record_id: serde_json::to_string(record_id)
81+
.expect("failed to serialize record_id to string"),
82+
content: text_content,
83+
token_ct,
84+
});
85+
}
11086

11187
// read prompt template
11288
let res_prompts: Result<PromptTemplate, spi::Error> = Spi::connect(|c| {
@@ -165,7 +141,7 @@ fn render_user_message(user_prompt_template: &str, context: &str, query: &str) -
165141
fn call_chat_completions(
166142
prompts: RenderedPrompt,
167143
model: &str,
168-
api_key: Option<&str>,
144+
api_key: Option<String>,
169145
) -> Result<String> {
170146
let openai_key = match api_key {
171147
Some(k) => k.to_string(),

src/init.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,14 +70,15 @@ pub fn init_embedding_table_query(
7070
table: &str,
7171
transformer: &str,
7272
transform_method: &TableMethod,
73+
api_key: Option<String>,
7374
) -> Vec<String> {
7475
check_input(job_name).expect("invalid job name");
7576
let col_type = match transformer {
7677
// https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
7778
// for anything but OpenAI, first call info endpoint to get the embedding dim of the model
7879
"text-embedding-ada-002" => "vector(1536)".to_owned(),
7980
_ => {
80-
let model_info: TransformerMetadata = sync_get_model_info(transformer)
81+
let model_info: TransformerMetadata = sync_get_model_info(transformer, api_key)
8182
.expect("failed to call vectorize.embedding_service_url");
8283
let dim = model_info.embedding_dimension;
8384
format!("vector({dim})")

src/search.rs

Lines changed: 60 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,10 @@ use crate::init::{self, VECTORIZE_QUEUE};
44
use crate::job::{create_insert_trigger, create_trigger_handler, create_update_trigger};
55
use crate::transformers::http_handler::sync_get_model_info;
66
use crate::transformers::openai;
7+
use crate::transformers::transform;
78
use crate::transformers::types::Inputs;
89
use crate::types;
10+
use crate::util;
911

1012
use anyhow::Result;
1113
use pgrx::prelude::*;
@@ -28,15 +30,16 @@ pub fn init_table(
2830
) -> Result<String> {
2931
let job_type = types::JobType::Columns;
3032

31-
// write job to table
32-
let init_job_q = init::init_job_query();
3333
let arguments = match serde_json::to_value(args) {
3434
Ok(a) => a,
3535
Err(e) => {
3636
error!("invalid json for argument `args`: {}", e);
3737
}
3838
};
39-
let api_key = arguments.get("api_key");
39+
let api_key = match arguments.get("api_key") {
40+
Some(k) => Some(serde_json::from_value::<String>(k.clone())?),
41+
None => None,
42+
};
4043

4144
// get prim key type
4245
let pkey_type = init::get_column_datatype(schema, table, primary_key);
@@ -46,8 +49,8 @@ pub fn init_table(
4649
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
4750
match transformer {
4851
"text-embedding-ada-002" => {
49-
let openai_key = match api_key {
50-
Some(k) => serde_json::from_value::<String>(k.clone())?,
52+
let openai_key = match api_key.clone() {
53+
Some(k) => k,
5154
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
5255
Some(k) => k,
5356
None => {
@@ -59,7 +62,7 @@ pub fn init_table(
5962
}
6063
t => {
6164
// make sure transformer exists
62-
let _ = sync_get_model_info(t).expect("transformer does not exist");
65+
let _ = sync_get_model_info(t, api_key.clone()).expect("transformer does not exist");
6366
}
6467
}
6568

@@ -71,12 +74,13 @@ pub fn init_table(
7174
table_method: table_method.clone(),
7275
primary_key: primary_key.to_string(),
7376
pkey_type,
74-
api_key: api_key
75-
.map(|k| serde_json::from_value::<String>(k.clone()).expect("error parsing api key")),
77+
api_key: api_key.clone(),
7678
};
7779
let params =
7880
pgrx::JsonB(serde_json::to_value(valid_params.clone()).expect("error serializing params"));
7981

82+
// write job to table
83+
let init_job_q = init::init_job_query();
8084
// using SPI here because it is unlikely that this code will be run anywhere but inside the extension.
8185
// background worker will likely be moved to an external container or service in near future
8286
let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
@@ -110,8 +114,14 @@ pub fn init_table(
110114
if ran.is_err() {
111115
error!("error creating job");
112116
}
113-
let init_embed_q =
114-
init::init_embedding_table_query(job_name, schema, table, transformer, &table_method);
117+
let init_embed_q = init::init_embedding_table_query(
118+
job_name,
119+
schema,
120+
table,
121+
transformer,
122+
&table_method,
123+
api_key,
124+
);
115125

116126
let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
117127
for q in init_embed_q {
@@ -198,14 +208,50 @@ pub fn init_table(
198208
Ok(format!("Successfully created job: {job_name}"))
199209
}
200210

211+
pub fn search(
212+
job_name: &str,
213+
query: &str,
214+
api_key: Option<String>,
215+
return_columns: Vec<String>,
216+
num_results: i32,
217+
) -> Result<Vec<pgrx::JsonB>> {
218+
let project_meta: VectorizeMeta = if let Ok(Some(js)) = util::get_vectorize_meta_spi(job_name) {
219+
js
220+
} else {
221+
error!("failed to get project metadata");
222+
};
223+
let proj_params: types::JobParams = serde_json::from_value(
224+
serde_json::to_value(project_meta.params).unwrap_or_else(|e| {
225+
error!("failed to serialize metadata: {}", e);
226+
}),
227+
)
228+
.unwrap_or_else(|e| error!("failed to deserialize metadata: {}", e));
229+
230+
let schema = proj_params.schema;
231+
let table = proj_params.table;
232+
233+
let embeddings = transform(query, &project_meta.transformer, api_key);
234+
235+
match project_meta.search_alg {
236+
types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search(
237+
job_name,
238+
&schema,
239+
&table,
240+
&return_columns,
241+
num_results,
242+
&embeddings[0],
243+
),
244+
}
245+
}
246+
201247
pub fn cosine_similarity_search(
202248
project: &str,
203249
schema: &str,
204250
table: &str,
205251
return_columns: &[String],
206252
num_results: i32,
207253
embeddings: &[f64],
208-
) -> Result<Vec<(pgrx::JsonB,)>, spi::Error> {
254+
) -> Result<Vec<pgrx::JsonB>> {
209255
let query = format!(
210256
"
211257
SELECT to_jsonb(t)
@@ -222,7 +268,8 @@ pub fn cosine_similarity_search(
222268
cols = return_columns.join(", "),
223269
);
224270
Spi::connect(|client| {
225-
let mut results: Vec<(pgrx::JsonB,)> = Vec::new();
271+
// let mut results: Vec<(pgrx::JsonB,)> = Vec::new();
272+
let mut results: Vec<pgrx::JsonB> = Vec::new();
226273
let tup_table = client.select(
227274
&query,
228275
None,
@@ -233,7 +280,7 @@ pub fn cosine_similarity_search(
233280
)?;
234281
for row in tup_table {
235282
match row["results"].value()? {
236-
Some(r) => results.push((r,)),
283+
Some(r) => results.push(r),
237284
None => error!("failed to get results"),
238285
}
239286
}

0 commit comments

Comments
 (0)