Skip to content

Commit a2e6da1

Browse files
authored
Initial rag (#52)
* update signature * optionally trim context * handle context limits * bump tomls * cleanup debug * update api name * change to rag * update test name * try dropping * update ci * re-create ext w/ each test * fix format
1 parent 52a2c4c commit a2e6da1

12 files changed

Lines changed: 279 additions & 49 deletions

File tree

.github/workflows/extension_ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ jobs:
7575
with:
7676
working-directory: ./
7777
- name: Cargo format
78-
run: cargo +nightly fmt --all --check
78+
run: cargo fmt --all --check
7979
- name: Clippy
8080
run: cargo clippy
8181

@@ -123,13 +123,13 @@ jobs:
123123
rm -rf ./target/pgrx-test-data-* || true
124124
- name: unit-test
125125
run: |
126-
cargo pgrx test
126+
make test-unit
127127
- name: integration-test
128128
run: |
129129
pgrx15_config=$(/usr/local/bin/stoml ~/.pgrx/config.toml configs.pg15)
130130
pg_version=$(/usr/local/bin/stoml Cargo.toml features.default)
131131
echo "\q" | make run
132-
cargo test -- --ignored --test-threads=1
132+
make test-integration
133133
134134
publish:
135135
if: github.event_name == 'release'

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "vectorize"
3-
version = "0.9.0"
3+
version = "0.10.0"
44
edition = "2021"
55
publish = false
66

Makefile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,10 @@ pgxn-zip: $(DISTNAME)-$(DISTVERSION).zip
2626

2727
clean:
2828
@rm -rf META.json $(DISTNAME)-$(DISTVERSION).zip
29+
30+
31+
test-integration:
32+
cargo test -- --ignored --test-threads=1
33+
34+
test-unit:
35+
cargo pgrx test

Trunk.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres."
66
homepage = "https://github.com/tembo-io/pg_vectorize"
77
documentation = "https://github.com/tembo-io/pg_vectorize"
88
categories = ["orchestration", "machine_learning"]
9-
version = "0.9.0"
9+
version = "0.10.0"
1010

1111
[build]
1212
postgres_version = "15"

src/api.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ fn transform_embeddings(
8989

9090
#[allow(clippy::too_many_arguments)]
9191
#[pg_extern]
92-
fn chat_table(
92+
fn init_rag(
9393
agent_name: &str,
9494
table_name: &str,
9595
unique_record_id: &str,
@@ -121,14 +121,30 @@ fn chat_table(
121121

122122
/// creates a table indexed with embeddings for chat completion workloads
123123
#[pg_extern]
124-
fn chat(
124+
fn rag(
125125
agent_name: &str,
126126
query: &str,
127+
// chat models: currently only supports gpt 3.5 and 4
128+
// https://platform.openai.com/docs/models/gpt-3-5-turbo
129+
// https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
127130
chat_model: default!(&str, "'gpt-3.5-turbo'"),
131+
// points to the type of prompt template to use
128132
task: default!(&str, "'question_answer'"),
129133
api_key: default!(Option<&str>, "NULL"),
134+
// number of records to include in the context
135+
num_context: default!(i32, 2),
136+
// truncates context to fit the model's context window
137+
force_trim: default!(bool, false),
130138
) -> Result<TableIterator<'static, (name!(chat_results, pgrx::JsonB),)>> {
131-
let resp = call_chat(agent_name, query, chat_model, task, api_key)?;
139+
let resp = call_chat(
140+
agent_name,
141+
query,
142+
chat_model,
143+
task,
144+
api_key,
145+
num_context,
146+
force_trim,
147+
)?;
132148
let iter = vec![(pgrx::JsonB(serde_json::to_value(resp)?),)];
133149
Ok(TableIterator::new(iter))
134150
}

0 commit comments

Comments
 (0)