@@ -4,8 +4,10 @@ use crate::init::{self, VECTORIZE_QUEUE};
44use crate :: job:: { create_insert_trigger, create_trigger_handler, create_update_trigger} ;
55use crate :: transformers:: http_handler:: sync_get_model_info;
66use crate :: transformers:: openai;
7+ use crate :: transformers:: transform;
78use crate :: transformers:: types:: Inputs ;
89use crate :: types;
10+ use crate :: util;
911
1012use anyhow:: Result ;
1113use pgrx:: prelude:: * ;
@@ -28,15 +30,16 @@ pub fn init_table(
2830) -> Result < String > {
2931 let job_type = types:: JobType :: Columns ;
3032
31- // write job to table
32- let init_job_q = init:: init_job_query ( ) ;
3333 let arguments = match serde_json:: to_value ( args) {
3434 Ok ( a) => a,
3535 Err ( e) => {
3636 error ! ( "invalid json for argument `args`: {}" , e) ;
3737 }
3838 } ;
39- let api_key = arguments. get ( "api_key" ) ;
39+ let api_key = match arguments. get ( "api_key" ) {
40+ Some ( k) => Some ( serde_json:: from_value :: < String > ( k. clone ( ) ) ?) ,
41+ None => None ,
42+ } ;
4043
4144 // get prim key type
4245 let pkey_type = init:: get_column_datatype ( schema, table, primary_key) ;
@@ -46,8 +49,8 @@ pub fn init_table(
4649 // key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
4750 match transformer {
4851 "text-embedding-ada-002" => {
49- let openai_key = match api_key {
50- Some ( k) => serde_json :: from_value :: < String > ( k . clone ( ) ) ? ,
52+ let openai_key = match api_key. clone ( ) {
53+ Some ( k) => k ,
5154 None => match guc:: get_guc ( guc:: VectorizeGuc :: OpenAIKey ) {
5255 Some ( k) => k,
5356 None => {
@@ -59,7 +62,7 @@ pub fn init_table(
5962 }
6063 t => {
6164 // make sure transformer exists
62- let _ = sync_get_model_info ( t) . expect ( "transformer does not exist" ) ;
65+ let _ = sync_get_model_info ( t, api_key . clone ( ) ) . expect ( "transformer does not exist" ) ;
6366 }
6467 }
6568
@@ -71,12 +74,13 @@ pub fn init_table(
7174 table_method : table_method. clone ( ) ,
7275 primary_key : primary_key. to_string ( ) ,
7376 pkey_type,
74- api_key : api_key
75- . map ( |k| serde_json:: from_value :: < String > ( k. clone ( ) ) . expect ( "error parsing api key" ) ) ,
77+ api_key : api_key. clone ( ) ,
7678 } ;
7779 let params =
7880 pgrx:: JsonB ( serde_json:: to_value ( valid_params. clone ( ) ) . expect ( "error serializing params" ) ) ;
7981
82+ // write job to table
83+ let init_job_q = init:: init_job_query ( ) ;
8084 // using SPI here because it is unlikely that this code will be run anywhere but inside the extension.
8185 // background worker will likely be moved to an external container or service in near future
8286 let ran: Result < _ , spi:: Error > = Spi :: connect ( |mut c| {
@@ -110,8 +114,14 @@ pub fn init_table(
110114 if ran. is_err ( ) {
111115 error ! ( "error creating job" ) ;
112116 }
113- let init_embed_q =
114- init:: init_embedding_table_query ( job_name, schema, table, transformer, & table_method) ;
117+ let init_embed_q = init:: init_embedding_table_query (
118+ job_name,
119+ schema,
120+ table,
121+ transformer,
122+ & table_method,
123+ api_key,
124+ ) ;
115125
116126 let ran: Result < _ , spi:: Error > = Spi :: connect ( |mut c| {
117127 for q in init_embed_q {
@@ -198,14 +208,50 @@ pub fn init_table(
198208 Ok ( format ! ( "Successfully created job: {job_name}" ) )
199209}
200210
211+ pub fn search (
212+ job_name : & str ,
213+ query : & str ,
214+ api_key : Option < String > ,
215+ return_columns : Vec < String > ,
216+ num_results : i32 ,
217+ ) -> Result < Vec < pgrx:: JsonB > > {
218+ let project_meta: VectorizeMeta = if let Ok ( Some ( js) ) = util:: get_vectorize_meta_spi ( job_name) {
219+ js
220+ } else {
221+ error ! ( "failed to get project metadata" ) ;
222+ } ;
223+ let proj_params: types:: JobParams = serde_json:: from_value (
224+ serde_json:: to_value ( project_meta. params ) . unwrap_or_else ( |e| {
225+ error ! ( "failed to serialize metadata: {}" , e) ;
226+ } ) ,
227+ )
228+ . unwrap_or_else ( |e| error ! ( "failed to deserialize metadata: {}" , e) ) ;
229+
230+ let schema = proj_params. schema ;
231+ let table = proj_params. table ;
232+
233+ let embeddings = transform ( query, & project_meta. transformer , api_key) ;
234+
235+ match project_meta. search_alg {
236+ types:: SimilarityAlg :: pgv_cosine_similarity => cosine_similarity_search (
237+ job_name,
238+ & schema,
239+ & table,
240+ & return_columns,
241+ num_results,
242+ & embeddings[ 0 ] ,
243+ ) ,
244+ }
245+ }
246+
201247pub fn cosine_similarity_search (
202248 project : & str ,
203249 schema : & str ,
204250 table : & str ,
205251 return_columns : & [ String ] ,
206252 num_results : i32 ,
207253 embeddings : & [ f64 ] ,
208- ) -> Result < Vec < ( pgrx:: JsonB , ) > , spi :: Error > {
254+ ) -> Result < Vec < pgrx:: JsonB > > {
209255 let query = format ! (
210256 "
211257 SELECT to_jsonb(t)
@@ -222,7 +268,8 @@ pub fn cosine_similarity_search(
222268 cols = return_columns. join( ", " ) ,
223269 ) ;
224270 Spi :: connect ( |client| {
225- let mut results: Vec < ( pgrx:: JsonB , ) > = Vec :: new ( ) ;
271+ // let mut results: Vec<(pgrx::JsonB,)> = Vec::new();
272+ let mut results: Vec < pgrx:: JsonB > = Vec :: new ( ) ;
226273 let tup_table = client. select (
227274 & query,
228275 None ,
@@ -233,7 +280,7 @@ pub fn cosine_similarity_search(
233280 ) ?;
234281 for row in tup_table {
235282 match row[ "results" ] . value ( ) ? {
236- Some ( r) => results. push ( ( r , ) ) ,
283+ Some ( r) => results. push ( r ) ,
237284 None => error ! ( "failed to get results" ) ,
238285 }
239286 }
0 commit comments