Skip to content

Commit 1e47832

Browse files
authored
delete instead of archive, batch OpenAI requests (#129)
* delete instead of archive
* vectorscale optional
* break up large openai requests
* Delete core/src/transformers/debug-test.py
1 parent 1e194c7 commit 1e47832

5 files changed

Lines changed: 50 additions & 21 deletions

File tree

core/src/transformers/http_handler.rs

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
use anyhow::Result;
2-
31
use crate::transformers::types::{
42
EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings,
53
};
4+
use anyhow::Result;
65
pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
76
resp: reqwest::Response,
87
method: &'static str,
@@ -26,22 +25,47 @@ pub async fn openai_embedding_request(
2625
timeout: i32,
2726
) -> Result<Vec<Vec<f64>>> {
2827
let client = reqwest::Client::new();
29-
let mut req = client
30-
.post(request.url)
31-
.timeout(std::time::Duration::from_secs(timeout as u64))
32-
.json::<EmbeddingPayload>(&request.payload)
33-
.header("Content-Type", "application/json");
34-
if let Some(key) = request.api_key {
35-
req = req.header("Authorization", format!("Bearer {}", key));
28+
29+
// openai request size limit is 2048 inputs
30+
let number_inputs = request.payload.input.len();
31+
let todo_requests: Vec<EmbeddingPayload> = if number_inputs > 2048 {
32+
split_vector(request.payload.input, 2048)
33+
.iter()
34+
.map(|chunk| EmbeddingPayload {
35+
input: chunk.clone(),
36+
model: request.payload.model.clone(),
37+
})
38+
.collect()
39+
} else {
40+
vec![request.payload]
41+
};
42+
43+
let mut all_embeddings: Vec<Vec<f64>> = Vec::with_capacity(number_inputs);
44+
45+
for request_payload in todo_requests.iter() {
46+
let mut req = client
47+
.post(&request.url)
48+
.timeout(std::time::Duration::from_secs(timeout as u64))
49+
.json::<EmbeddingPayload>(request_payload)
50+
.header("Content-Type", "application/json");
51+
if let Some(key) = request.api_key.as_ref() {
52+
req = req.header("Authorization", format!("Bearer {}", key));
53+
}
54+
let resp = req.send().await?;
55+
let embedding_resp: EmbeddingResponse =
56+
handle_response::<EmbeddingResponse>(resp, "embeddings").await?;
57+
let embeddings: Vec<Vec<f64>> = embedding_resp
58+
.data
59+
.iter()
60+
.map(|d| d.embedding.clone())
61+
.collect();
62+
all_embeddings.extend(embeddings);
3663
}
37-
let resp = req.send().await?;
38-
let embedding_resp = handle_response::<EmbeddingResponse>(resp, "embeddings").await?;
39-
let embeddings = embedding_resp
40-
.data
41-
.iter()
42-
.map(|d| d.embedding.clone())
43-
.collect();
44-
Ok(embeddings)
64+
Ok(all_embeddings)
65+
}
66+
67+
/// Splits an owned vector of strings into consecutive chunks of at most
/// `chunk_size` elements, preserving order; the final chunk may be shorter.
/// An empty input yields an empty result.
///
/// Takes `vec` by value, so elements are *moved* into the chunks rather than
/// cloned — the original `chunks().map(|c| c.to_vec())` deep-cloned every
/// `String` even though the function already owned them.
///
/// # Panics
/// Panics if `chunk_size` is 0 (mirrors the behavior of `slice::chunks`).
fn split_vector(vec: Vec<String>, chunk_size: usize) -> Vec<Vec<String>> {
    assert!(chunk_size > 0, "chunk_size must be non-zero");
    // ceil(len / chunk_size) chunks; preallocate to avoid regrowth.
    let mut chunks = Vec::with_capacity((vec.len() + chunk_size - 1) / chunk_size);
    let mut items = vec.into_iter();
    loop {
        // `by_ref` lets us repeatedly take the next `chunk_size` elements
        // from the same iterator without consuming it.
        let chunk: Vec<String> = items.by_ref().take(chunk_size).collect();
        if chunk.is_empty() {
            break;
        }
        chunks.push(chunk);
    }
    chunks
}
4670

4771
// merges the vec of inputs with the embedding responses

extension/Trunk.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres."
66
homepage = "https://github.com/tembo-io/pg_vectorize"
77
documentation = "https://github.com/tembo-io/pg_vectorize"
88
categories = ["orchestration", "machine_learning"]
9-
version = "0.16.0"
9+
version = "0.17.0"
1010

1111
[build]
1212
postgres_version = "15"

extension/src/transformers/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,13 @@ pub fn transform(input: &str, transformer: &Model, api_key: Option<String>) -> V
4343
input: vec![input.to_string()],
4444
model: transformer.name.to_string(),
4545
};
46+
47+
let url = match guc::get_guc(guc::VectorizeGuc::OpenAIServiceUrl) {
48+
Some(k) => k,
49+
None => OPENAI_BASE_URL.to_string(),
50+
};
4651
EmbeddingRequest {
47-
url: format!("{OPENAI_BASE_URL}/embeddings"),
52+
url: format!("{url}/embeddings"),
4853
payload: embedding_request,
4954
api_key: Some(api_key.to_string()),
5055
}

extension/src/workers/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub async fn run_worker(
5151

5252
// delete message from queue
5353
if delete_it {
54-
match queue.archive(queue_name, msg_id).await {
54+
match queue.delete(queue_name, msg_id).await {
5555
Ok(_) => {
5656
info!("pg-vectorize: deleted message: {}", msg_id);
5757
}

extension/vectorize.control

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ module_pathname = '$libdir/vectorize'
44
relocatable = false
55
superuser = true
66
schema = 'vectorize'
7-
requires = 'pg_cron,pgmq,vector,vectorscale'
7+
requires = 'pg_cron,pgmq,vector'

0 commit comments

Comments (0)