
Commit 1e194c7

Added Ollama support in transform function (#106)
* Embeddings generation
* Ollama Embeddings generator added. Embeddings can also be generated using Ollama models via the transform function.
* Merged with main and removed ollama insertion in backwards compatibility block
* Updated search.rs with main
* Resolved conflicts
* Updated branch with main
* Fixed formatting
* fix pgmq install
* fmt
* fix path for pgmq install
1 parent aae5a6d commit 1e194c7

5 files changed

Lines changed: 107 additions & 6 deletions


core/src/transformers/ollama.rs

Lines changed: 52 additions & 0 deletions
@@ -2,6 +2,8 @@ use anyhow::Result;
 use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
 use url::Url;
 
+use super::types::EmbeddingRequest;
+
 pub struct OllamaInstance {
     pub model_name: String,
     pub instance: Ollama,
@@ -11,6 +13,8 @@ pub trait LLMFunctions {
     fn new(model_name: String, url: String) -> Self;
     #[allow(async_fn_in_trait)]
     async fn generate_reponse(&self, prompt_text: String) -> Result<String, String>;
+    #[allow(async_fn_in_trait)]
+    async fn generate_embedding(&self, inputs: String) -> Result<Vec<f64>, String>;
 }
 
 impl LLMFunctions for OllamaInstance {
@@ -38,6 +42,16 @@ impl LLMFunctions for OllamaInstance {
             Err(e) => Err(e.to_string()),
         }
     }
+    async fn generate_embedding(&self, input: String) -> Result<Vec<f64>, String> {
+        let embed = self
+            .instance
+            .generate_embeddings(self.model_name.clone(), input, None)
+            .await;
+        match embed {
+            Ok(res) => Ok(res.embeddings),
+            Err(e) => Err(e.to_string()),
+        }
+    }
 }
 
 pub fn ollama_embedding_dim(model_name: &str) -> i32 {
@@ -46,3 +60,41 @@ pub fn ollama_embedding_dim(model_name: &str) -> i32 {
         _ => 1536,
     }
 }
+
+pub fn check_model_host(url: &str) -> Result<String, String> {
+    let runtime = tokio::runtime::Builder::new_current_thread()
+        .enable_io()
+        .enable_time()
+        .build()
+        .unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e));
+
+    runtime.block_on(async {
+        let response = reqwest::get(url).await.unwrap();
+        match response.status() {
+            reqwest::StatusCode::OK => Ok(format!("Success! {:?}", response)),
+            _ => Err(format!("Error! {:?}", response)),
+        }
+    })
+}
+
+pub fn generate_embeddings(request: EmbeddingRequest) -> Result<Vec<Vec<f64>>> {
+    let runtime = tokio::runtime::Builder::new_current_thread()
+        .enable_io()
+        .enable_time()
+        .build()
+        .unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e));
+
+    runtime.block_on(async {
+        let instance = OllamaInstance::new(request.payload.model, request.url);
+        let mut embeddings: Vec<Vec<f64>> = vec![];
+        for input in request.payload.input {
+            let response = instance.generate_embedding(input).await;
+            let embedding = match response {
+                Ok(embed) => embed,
+                Err(e) => panic!("Unable to generate embeddings.\nError: {:?}", e),
+            };
+            embeddings.push(embedding);
+        }
+        Ok(embeddings)
+    })
+}
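
For context, a minimal sketch (not part of the commit) of how the new per-input API added above might be called from async code; the localhost URL and the "llama2" model name are illustrative assumptions.

// Sketch only: embedding a single string with the new OllamaInstance API.
// Assumes a local Ollama server at the default port and a model already pulled.
use vectorize_core::transformers::ollama::{LLMFunctions, OllamaInstance};

async fn embed_one(text: &str) -> Result<Vec<f64>, String> {
    let instance = OllamaInstance::new(
        "llama2".to_string(),
        "http://localhost:11434".to_string(),
    );
    // Delegates to ollama-rs and returns the raw embedding vector for one input.
    instance.generate_embedding(text.to_string()).await
}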

core/src/types.rs

Lines changed: 3 additions & 0 deletions
@@ -1,4 +1,5 @@
 use chrono::serde::ts_seconds_option::deserialize as from_tsopt;
+
 use serde::{Deserialize, Serialize};
 use sqlx::types::chrono::Utc;
 use sqlx::FromRow;
@@ -168,10 +169,12 @@ pub enum ModelError {
 impl Model {
     pub fn new(input: &str) -> Result<Self, ModelError> {
         let mut parts: Vec<&str> = input.split('/').collect();
+
         let missing_source = parts.len() < 2;
         if parts.len() > 3 {
             return Err(ModelError::InvalidFormat(input.to_string()));
         }
+
         if missing_source && parts[0] == "text-embedding-ada-002" {
             // for backwards compatibility, prepend "openai" to text-embedding-ada-2
             parts.insert(0, "openai");

extension/Makefile

Lines changed: 1 addition & 2 deletions
@@ -73,8 +73,7 @@ install-pgvector:
 
 install-pgmq:
 	git clone https://github.com/tembo-io/pgmq.git && \
-	cd pgmq && \
-	PG_CONFIG=${PGRX_PG_CONFIG} make clean && \
+	cd pgmq/pgmq-extension && \
 	PG_CONFIG=${PGRX_PG_CONFIG} make && \
 	PG_CONFIG=${PGRX_PG_CONFIG} make install && \
 	cd .. && rm -rf pgmq

extension/src/search.rs

Lines changed: 19 additions & 1 deletion
@@ -8,6 +8,7 @@ use crate::util;
 
 use anyhow::{Context, Result};
 use pgrx::prelude::*;
+use vectorize_core::transformers::ollama::check_model_host;
 use vectorize_core::types::{self, Model, ModelSource, TableMethod, VectorizeMeta};
 
 #[allow(clippy::too_many_arguments)]
@@ -69,9 +70,26 @@ pub fn init_table(
             sync_get_model_info(&transformer.fullname, api_key.clone())
                 .context("transformer does not exist")?;
         }
-        ModelSource::Ollama | ModelSource::Tembo => {
+        ModelSource::Tembo => {
             error!("Ollama/Tembo not implemented for search yet");
         }
+        ModelSource::Ollama => {
+            let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) {
+                Some(k) => k,
+                None => {
+                    error!("failed to get Ollama url from GUC");
+                }
+            };
+            let res = check_model_host(&url);
+            match res {
+                Ok(_) => {
+                    info!("Model host active!")
+                }
+                Err(e) => {
+                    error!("Error with model host: {:?}", e)
+                }
+            }
+        }
     }
 
     let valid_params = types::JobParams {
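
As a rough illustration of the reachability check that init_table now performs for Ollama models, here is a standalone sketch; the URL is an assumed default local Ollama address rather than the GUC value the extension actually reads.

// Sketch only: the same host check, outside of the extension.
use vectorize_core::transformers::ollama::check_model_host;

fn main() {
    match check_model_host("http://localhost:11434") {
        Ok(msg) => println!("Model host active! {msg}"),
        Err(e) => eprintln!("Error with model host: {e:?}"),
    }
}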

extension/src/transformers/mod.rs

Lines changed: 32 additions & 3 deletions
@@ -7,6 +7,7 @@ use generic::get_env_interpolated_guc;
 use pgrx::prelude::*;
 
 use vectorize_core::transformers::http_handler::openai_embedding_request;
+use vectorize_core::transformers::ollama::generate_embeddings;
 use vectorize_core::transformers::openai::OPENAI_BASE_URL;
 use vectorize_core::transformers::types::{EmbeddingPayload, EmbeddingRequest};
 use vectorize_core::types::{Model, ModelSource};
@@ -61,14 +62,38 @@ pub fn transform(input: &str, transformer: &Model, api_key: Option<String>) -> V
                 api_key: api_key.map(|s| s.to_string()),
             }
         }
-        ModelSource::Ollama => error!("Ollama transformer not implemented yet"),
+        ModelSource::Ollama => {
+            let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) {
+                Some(k) => k,
+                None => {
+                    error!("failed to get Ollama url from GUC");
+                }
+            };
+
+            let embedding_request = EmbeddingPayload {
+                input: vec![input.to_string()],
+                model: transformer.name.to_string(),
+            };
+
+            EmbeddingRequest {
+                url,
+                payload: embedding_request,
+                api_key: None,
+            }
+        }
     };
     let timeout = EMBEDDING_REQ_TIMEOUT_SEC.get();
 
     match transformer.source {
-        ModelSource::Ollama | ModelSource::Tembo => {
-            error!("Ollama/Tembo transformer not implemented yet")
+        ModelSource::Ollama => {
+            // Call the embeddings generation function
+            let embeddings = generate_embeddings(embedding_request);
+            match embeddings {
+                Ok(k) => k,
+                Err(e) => error!("error getting embeddings: {}", e),
+            }
         }
+
         ModelSource::OpenAI | ModelSource::SentenceTransformers => {
            match runtime
                 .block_on(async { openai_embedding_request(embedding_request, timeout).await })
@@ -79,5 +104,9 @@ pub fn transform(input: &str, transformer: &Model, api_key: Option<String>) -> V
                 }
             }
         }
+
+        ModelSource::Tembo => {
+            error!("Embeddings support not added for Tembo yet!")
+        }
     }
 }
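
Condensed into a standalone function, the new Ollama path through transform looks roughly like the sketch below. This is illustrative only: the service_url parameter stands in for the OllamaServiceUrl GUC that the extension reads, and the struct fields mirror the diff above.

// Sketch only: the Ollama branch of transform() without the pgrx plumbing.
use anyhow::Result;
use vectorize_core::transformers::ollama::generate_embeddings;
use vectorize_core::transformers::types::{EmbeddingPayload, EmbeddingRequest};

fn ollama_transform(input: &str, model_name: &str, service_url: &str) -> Result<Vec<Vec<f64>>> {
    let request = EmbeddingRequest {
        url: service_url.to_string(),
        payload: EmbeddingPayload {
            input: vec![input.to_string()],
            model: model_name.to_string(),
        },
        api_key: None, // a local Ollama server needs no API key
    };
    // generate_embeddings builds its own current-thread tokio runtime internally,
    // which is why the synchronous extension code can call it directly.
    generate_embeddings(request)
}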
