diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ea8c4bf
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+/target
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..485201f
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,14 @@
+[package]
+name = "rusqlite-vss"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+# rusqlite = { version = "0.31.0", features=["bundled"] }
+rusqlite = { version = "0.29.0", features = ["bundled"] }
+sqlite-vss = { version = "0.1.2", features = ["download-libs"] }
+tokio = { version = "1.37.0", features = ["full"] }
+axum = "0.7.5"
+anyhow = "1.0.86"
+serde = { version = "1", features = ["derive"] }
+serde_json = "1.0.117"
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..457702a
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,63 @@
+//! Server entry point: wires the HTTP routes in [`service`] to a shared SQLite connection.
+
+use std::sync::Arc;
+
+use axum::{
+    routing::{get, post, put},
+    Router,
+};
+use tokio::sync::Mutex;
+
+pub mod service;
+pub mod store;
+
+/// Address the server binds to when the `LISTEN_ADDR` environment variable cannot be read.
+const DEFAULT_LISTEN_ADDR: &str = "0.0.0.0:3000";
+
+/// Registers the SQLite extensions, opens `store.sqlite`, and serves the HTTP API.
+///
+/// # Errors
+///
+/// Returns an error if the database cannot be opened, the address cannot be bound,
+/// or the server stops with an error.
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    // Register the extensions first so the connection opened below picks them up.
+    store::init();
+
+    // `unwrap_or_else` builds the fallback string only when the lookup fails.
+    let addr = std::env::var("LISTEN_ADDR").unwrap_or_else(|_| DEFAULT_LISTEN_ADDR.to_string());
+
+    let db = store::open("store.sqlite")?;
+
+    let app = Router::new()
+        .route(
+            "/collections/:name",
+            put(service::create_collections)
+                .get(service::get_collections_info)
+                .delete(service::delete_collection),
+        )
+        .route(
+            "/collections/:name/points",
+            put(service::add_points).post(service::get_points),
+        )
+        .route(
+            "/collections/:name/points/delete",
+            post(service::delete_points),
+        )
+        // The search handler existed in `service` but had no route, so it was unreachable.
+        .route(
+            "/collections/:name/points/search",
+            post(service::search_points),
+        )
+        .route(
+            "/collections/:name/points/:point_id",
+            get(service::get_point),
+        )
+        .with_state(Arc::new(Mutex::new(db)));
+
+    let listener = tokio::net::TcpListener::bind(addr).await?;
+    axum::serve(listener, app).await?;
+
+    Ok(())
+}
diff --git a/src/service.rs b/src/service.rs
new file mode 100644
index 0000000..532c3bd
--- /dev/null
+++ b/src/service.rs
@@ -0,0 +1,328 @@
+use std::sync::Arc;
+
+use axum::{
+    extract::{Path, State},
+    response::IntoResponse,
+    Json,
+};
+
+use rusqlite::OptionalExtension;
+use tokio::sync::Mutex;
+
+use crate::store;
+
+/// Response envelope returned by the handlers in this module.
+#[derive(Debug, serde::Serialize)]
+pub struct APIResult<T> {
+    /// Handler-specific result value.
+    pub result: T,
+    /// `Some("ok")` when the handler succeeded, `None` otherwise.
+    pub status: Option<String>,
+    /// Error message when the handler failed, `None` otherwise.
+    pub error: Option<String>,
+}
+
+/// Request body of [`create_collections`].
+#[derive(Debug, serde::Deserialize)]
+pub struct CreateConllections {
+    pub vectors: CreateConllectionsVectors,
+}
+
+/// Vector settings of the collection being created.
+#[derive(Debug, serde::Deserialize)]
+pub struct CreateConllectionsVectors {
+    /// Forwarded to [`store::create_collections`] as its `size` argument.
+    pub size: usize,
+}
+
+pub type CreateConllectionsResult = APIResult<bool>;
+
+/// Creates collection `name` with the vector size given in the request body.
+///
+/// Responds `200` with `result: true` on success; any store error is reported as
+/// `409 Conflict` with `result: false` and the error message.
+pub async fn create_collections(
+    Path(name): Path<String>,
+    State(db): State<Arc<Mutex<rusqlite::Connection>>>,
+    Json(create_conllections): Json<CreateConllections>,
+) -> impl IntoResponse {
+    let conn = db.lock().await;
+    match store::create_collections(&conn, &name, create_conllections.vectors.size) {
+        Ok(()) => (
+            axum::http::StatusCode::OK,
+            Json(CreateConllectionsResult {
+                result: true,
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Err(e) => (
+            axum::http::StatusCode::CONFLICT,
+            Json(CreateConllectionsResult {
+                result: false,
+                status: None,
+                error: Some(e.to_string()),
+            }),
+        ),
+    }
+}
+
+/// Summary of a collection as returned by [`get_collections_info`].
+#[derive(Debug, serde::Serialize)]
+pub struct CollectionsInfo {
+    pub points_count: u64,
+}
+
+pub type GetCollectionsResult = APIResult<CollectionsInfo>;
+
+/// Reports the number of points in collection `name`.
+///
+/// Responds `200` with the count on success; on a store error responds `500` with
+/// `points_count: 0` and the error message.
+pub async fn get_collections_info(
+    Path(name): Path<String>,
+    State(db): State<Arc<Mutex<rusqlite::Connection>>>,
+) -> impl IntoResponse {
+    let conn = db.lock().await;
+    match store::get_collections_info(&conn, &name) {
+        Ok(info) => (
+            axum::http::StatusCode::OK,
+            Json(GetCollectionsResult {
+                result: CollectionsInfo {
+                    points_count: info.points_count,
+                },
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Err(e) => (
+            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+            Json(GetCollectionsResult {
+                result: CollectionsInfo { points_count: 0 },
+                status: None,
+                error: Some(e.to_string()),
+            }),
+        ),
+    }
+}
+
+/// Request body of the add-points handler.
+#[derive(Debug, serde::Deserialize)]
+pub struct AddPoints {
+    pub points: Vec<Point>,
+}
+
+/// A point in a collection: its id, vector, and optional JSON-object payload.
+#[derive(Debug, serde::Deserialize, serde::Serialize)]
+pub struct Point {
+    pub id: u64,
+    pub vector: Vec<f32>,
+    pub payload: Option<serde_json::Map<String, serde_json::Value>>,
+}
+
+pub type AddPointsResult = APIResult<Option<Vec<u64>>>;
+
+/// Stores the points from the request body in collection `name`.
+///
+/// Responds `200` with the ids returned by [`store::add_point`], or `500` with the
+/// error message.
+pub async fn add_points(
+    Path(name): Path<String>,
+    State(db): State<Arc<Mutex<rusqlite::Connection>>>,
+    Json(points): Json<AddPoints>,
+) -> impl IntoResponse {
+    let conn = db.lock().await;
+    match store::add_point(&conn, &name, &points.points) {
+        Ok(success_id) => (
+            axum::http::StatusCode::OK,
+            Json(AddPointsResult {
+                result: Some(success_id),
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Err(e) => (
+            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+            Json(AddPointsResult {
+                result: None,
+                status: None,
+                error: Some(e.to_string()),
+            }),
+        ),
+    }
+}
+
+/// Request body of [`get_points`].
+#[derive(Debug, serde::Deserialize)]
+pub struct GetPoints {
+    ids: Vec<u64>,
+}
+
+pub type GetPointsResult = APIResult<Option<Vec<Point>>>;
+
+/// Returns the points of collection `name` whose ids are listed in the request body.
+///
+/// A "no rows" outcome from the store is reported as `200` with an empty list; other
+/// store errors as `500` with the error message.
+pub async fn get_points(
+    Path(name): Path<String>,
+    State(db): State<Arc<Mutex<rusqlite::Connection>>>,
+    Json(ids): Json<GetPoints>,
+) -> impl IntoResponse {
+    let r = {
+        // Release the connection lock before the response is built.
+        let conn = db.lock().await;
+        store::get_points(&conn, &name, ids.ids).optional()
+    };
+
+    match r {
+        Ok(Some(points)) => (
+            axum::http::StatusCode::OK,
+            Json(GetPointsResult {
+                result: Some(points),
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Ok(None) => (
+            axum::http::StatusCode::OK,
+            Json(GetPointsResult {
+                result: Some(Vec::new()),
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Err(e) => (
+            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+            Json(GetPointsResult {
+                result: None,
+                status: None,
+                error: Some(e.to_string()),
+            }),
+        ),
+    }
+}
+
+pub type GetPointResult = APIResult<Option<Point>>;
+
+/// Returns the single point `point_id` of collection `name`.
+///
+/// Both path parameters are taken through one tuple extractor: a `Path` extractor
+/// deserializes all parameters of the matched route, so the two separate
+/// single-value extractors used before were rejected on this two-parameter route.
+///
+/// Responds `404` when the store reports no matching row and `500` on other errors.
+pub async fn get_point(
+    Path((name, point_id)): Path<(String, u64)>,
+    State(db): State<Arc<Mutex<rusqlite::Connection>>>,
+) -> impl IntoResponse {
+    let conn = db.lock().await;
+    match store::get_point(&conn, &name, point_id).optional() {
+        Ok(Some(point)) => (
+            axum::http::StatusCode::OK,
+            Json(GetPointResult {
+                result: Some(point),
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Ok(None) => (
+            axum::http::StatusCode::NOT_FOUND,
+            Json(GetPointResult {
+                result: None,
+                status: None,
+                error: Some(format!(
+                    "Not found: Point with id {} does not exists",
+                    point_id
+                )),
+            }),
+        ),
+        Err(e) => (
+            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+            Json(GetPointResult {
+                result: None,
+                status: None,
+                error: Some(e.to_string()),
+            }),
+        ),
+    }
+}
+
+/// Request body of [`search_points`].
+#[derive(Debug, serde::Deserialize)]
+pub struct Search {
+    pub vector: Vec<f32>,
+    pub limit: usize,
+}
+
+/// A search hit: the point's fields plus the `distance` value reported by the query.
+#[derive(Debug, serde::Serialize)]
+pub struct ScoredPoint {
+    pub id: u64,
+    pub vector: Vec<f32>,
+    pub payload: Option<serde_json::Map<String, serde_json::Value>>,
+    pub score: f32,
+}
+
+pub type SearchResult = APIResult<Option<Vec<ScoredPoint>>>;
+
+/// Searches collection `name` for the points nearest to the vector in the request body.
+///
+/// A "no rows" outcome from the store is reported as `200` with an empty list; other
+/// store errors as `500` with the error message.
+pub async fn search_points(
+    Path(name): Path<String>,
+    State(db): State<Arc<Mutex<rusqlite::Connection>>>,
+    Json(search): Json<Search>,
+) -> impl IntoResponse {
+    let conn = db.lock().await;
+    let r = store::search_points(&conn, &name, search.vector.as_slice(), search.limit).optional();
+    match r {
+        Ok(Some(points)) => (
+            axum::http::StatusCode::OK,
+            Json(SearchResult {
+                result: Some(points),
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Ok(None) => (
+            axum::http::StatusCode::OK,
+            Json(SearchResult {
+                result: Some(Vec::new()),
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Err(e) => (
+            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+            Json(SearchResult {
+                result: None,
+                status: None,
+                error: Some(e.to_string()),
+            }),
+        ),
+    }
+}
+
+/// Request body of [`delete_points`]: the ids of the points to remove.
+#[derive(Debug, serde::Deserialize)]
+pub struct DeletePoints {
+    pub points: Vec<u64>,
+}
+
+pub type DeletePointsResult = APIResult<bool>;
+
+/// Deletes the listed points from collection `name`.
+///
+/// Responds `200` with `result: true`, or `500` with `result: false` and the error message.
+pub async fn delete_points(
+    Path(name): Path<String>,
+    State(db): State<Arc<Mutex<rusqlite::Connection>>>,
+    Json(points): Json<DeletePoints>,
+) -> impl IntoResponse {
+    let conn = db.lock().await;
+    match store::delete_points(&conn, &name, points.points) {
+        Ok(_) => (
+            axum::http::StatusCode::OK,
+            Json(DeletePointsResult {
+                result: true,
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Err(e) => (
+            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+            Json(DeletePointsResult {
+                result: false,
+                status: None,
+                error: Some(e.to_string()),
+            }),
+        ),
+    }
+}
+
+/// Drops collection `name`.
+///
+/// Responds `200` with `result: true`, or `500` with `result: false` and the error message.
+pub async fn delete_collection(
+    Path(name): Path<String>,
+    State(db): State<Arc<Mutex<rusqlite::Connection>>>,
+) -> impl IntoResponse {
+    let conn = db.lock().await;
+    match store::delete_collection(&conn, &name) {
+        Ok(_) => (
+            axum::http::StatusCode::OK,
+            Json(DeletePointsResult {
+                result: true,
+                status: Some("ok".to_string()),
+                error: None,
+            }),
+        ),
+        Err(e) => (
+            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+            Json(DeletePointsResult {
+                result: false,
+                status: None,
+                error: Some(e.to_string()),
+            }),
+        ),
+    }
+}
diff --git a/src/store.rs b/src/store.rs
new file mode 100644
index 0000000..e798e39
--- /dev/null
+++ b/src/store.rs
@@ -0,0 +1,417 @@
+use std::{collections::HashMap, mem::size_of};
+
+use rusqlite::{ffi::sqlite3_auto_extension, params, Connection};
+use sqlite_vss::{sqlite3_vector_init, sqlite3_vss_init};
+
+use crate::service::{CollectionsInfo, Point, ScoredPoint};
+
+/// JSON object stored with a point; same type as the `payload` fields in `service`.
+type Payload = serde_json::Map<String, serde_json::Value>;
+
+/// Registers the sqlite-vss entry points as SQLite auto-extensions.
+///
+/// Call this before opening a connection that is used with the functions below.
+pub fn init() {
+    // SAFETY: both arguments are the extension entry points exported by the
+    // `sqlite_vss` crate, handed unchanged to SQLite's registration function.
+    unsafe {
+        sqlite3_auto_extension(Some(sqlite3_vector_init));
+        sqlite3_auto_extension(Some(sqlite3_vss_init));
+    }
+}
+
+/// Opens the SQLite database at `path`.
+pub fn open(path: &str) -> rusqlite::Result<Connection> {
+    rusqlite::Connection::open(path)
+}
+
+/// Wraps a validation or conversion failure in a [`rusqlite::Error`].
+///
+/// The public functions here return `rusqlite::Result`, so failures that do not come
+/// from SQLite are reported as `ToSqlConversionFailure` rather than by panicking.
+fn invalid_input(err: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> rusqlite::Error {
+    rusqlite::Error::ToSqlConversionFailure(err.into())
+}
+
+/// Checks that `name` can be spliced into SQL as a bare identifier.
+///
+/// Table names cannot be bound as statement parameters, so the queries below are built
+/// with `format!`. Allowing only letters, digits and `_` (with no leading digit) keeps
+/// a request-supplied collection name from injecting SQL.
+fn check_collection_name(name: &str) -> rusqlite::Result<()> {
+    let mut chars = name.chars();
+    let starts_ok = matches!(chars.next(), Some(c) if c == '_' || c.is_alphabetic());
+    if starts_ok && chars.all(|c| c == '_' || c.is_alphanumeric()) {
+        Ok(())
+    } else {
+        Err(invalid_input(format!("invalid collection name: {name:?}")))
+    }
+}
+
+/// Runs `sql` as a batch inside one transaction.
+///
+/// The transaction rolls back when dropped without a commit, so a failing statement
+/// does not leave a `BEGIN` open on the shared connection, as the earlier literal
+/// `BEGIN; ...; COMMIT;` batches did.
+fn run_in_transaction(conn: &Connection, sql: &str) -> rusqlite::Result<()> {
+    let tx = conn.unchecked_transaction()?;
+    tx.execute_batch(sql)?;
+    tx.commit()
+}
+
+/// Formats ids as a comma-separated list for use inside `rowid in (...)`.
+fn join_ids(ids: impl IntoIterator<Item = u64>) -> String {
+    ids.into_iter()
+        .map(|id| id.to_string())
+        .collect::<Vec<_>>()
+        .join(",")
+}
+
+/// Decodes stored payload text; missing or unparsable text yields `None`.
+fn parse_payload(text: Option<&str>) -> Option<Payload> {
+    text.and_then(|t| serde_json::from_str::<Option<Payload>>(t).unwrap_or_default())
+}
+
+/// Loads the payloads of the rows in `id_list` from `{name}_payload`, keyed by rowid.
+///
+/// `id_list` must be a comma-separated list of integers such as [`join_ids`] produces,
+/// and `name` must already have passed [`check_collection_name`].
+fn load_payloads(
+    conn: &Connection,
+    name: &str,
+    id_list: &str,
+) -> rusqlite::Result<HashMap<u64, Option<Payload>>> {
+    let mut stmt = conn.prepare(&format!(
+        "SELECT rowid,payload FROM {name}_payload WHERE rowid in ({id_list});"
+    ))?;
+    let rows = stmt.query_map([], |row| {
+        let id: u64 = row.get(0)?;
+        let text: Option<String> = row.get(1)?;
+        Ok((id, parse_payload(text.as_deref())))
+    })?;
+    rows.collect()
+}
+
+/// Creates the `vss0` table `name` (column `point({size})`) and its `{name}_payload` table.
+///
+/// Both statements use `IF NOT EXISTS` and run in a single transaction.
+///
+/// # Errors
+///
+/// Fails if `name` is not a valid identifier or SQLite rejects a statement.
+pub fn create_collections(conn: &Connection, name: &str, size: usize) -> rusqlite::Result<()> {
+    check_collection_name(name)?;
+    let sql = format!(
+        r#"
+        CREATE VIRTUAL TABLE IF NOT EXISTS {name} USING vss0(point({size}));
+        CREATE TABLE IF NOT EXISTS {name}_payload (rowid INTEGER PRIMARY KEY, payload TEXT);
+        "#
+    );
+    run_in_transaction(conn, &sql)
+}
+
+/// Counts the rows of collection `name`.
+///
+/// # Errors
+///
+/// Fails if `name` is invalid or the query fails; reading the count used to `unwrap()`
+/// and panic instead.
+pub fn get_collections_info(conn: &Connection, name: &str) -> rusqlite::Result<CollectionsInfo> {
+    check_collection_name(name)?;
+    let sql = format!("SELECT COUNT(*) FROM {name};");
+    let points_count: u64 = conn.query_row(&sql, [], |row| row.get(0))?;
+    Ok(CollectionsInfo { points_count })
+}
+
+#[test]
+fn test_collections() {
+    init();
+    let conn = rusqlite::Connection::open_in_memory().unwrap();
+    create_collections(&conn, "test_vss", 4).unwrap();
+    let r = get_collections_info(&conn, "test_vss").unwrap();
+    assert_eq!(r.points_count, 0);
+}
+
+#[test]
+fn test_invalid_collection_name() {
+    init();
+    let conn = rusqlite::Connection::open_in_memory().unwrap();
+    assert!(create_collections(&conn, "x; DROP TABLE y", 4).is_err());
+    assert!(get_collections_info(&conn, "").is_err());
+}
+
+/// Decodes a blob of native-endian `f32` values; trailing bytes that do not fill a whole
+/// `f32` are ignored.
+///
+/// Decoding chunk by chunk avoids the earlier `*const u8` to `*const f32` cast, which is
+/// only sound when the byte buffer happens to be aligned for `f32`.
+fn blob_to_vector(blob: &[u8]) -> Vec<f32> {
+    blob.chunks_exact(size_of::<f32>())
+        .map(|chunk| {
+            let bytes: [u8; size_of::<f32>()] =
+                chunk.try_into().expect("chunks_exact yields whole f32-sized chunks");
+            f32::from_ne_bytes(bytes)
+        })
+        .collect()
+}
+
+/// Encodes a vector as the native-endian bytes of its `f32` values.
+fn vector_to_blob(vector: &[f32]) -> Vec<u8> {
+    vector.iter().flat_map(|v| v.to_ne_bytes()).collect()
+}
+
+#[test]
+fn test_blob_round_trip() {
+    let vector = [0.05_f32, -1.5, 0.0, 3.25];
+    let blob = vector_to_blob(&vector);
+    assert_eq!(blob.len(), vector.len() * size_of::<f32>());
+    assert_eq!(blob_to_vector(&blob), vector);
+}
+
+/// Inserts `points` into collection `name` and returns their ids in input order.
+///
+/// The whole batch is written in one transaction, so a failure part-way through stores
+/// none of it instead of leaving a vector without its payload row.
+///
+/// # Errors
+///
+/// Fails if `name` is invalid, an id does not fit in an `i64` rowid, a payload cannot be
+/// serialized, or SQLite rejects an insert.
+pub fn add_point(conn: &Connection, name: &str, points: &[Point]) -> rusqlite::Result<Vec<u64>> {
+    check_collection_name(name)?;
+    let tx = conn.unchecked_transaction()?;
+    let mut success_id = Vec::with_capacity(points.len());
+    {
+        // Scoped so both statements are dropped before the transaction is committed.
+        let mut vector_stmt = tx.prepare(&format!(
+            "INSERT INTO {name}(rowid,point) VALUES (?1, vector_from_raw(?2))"
+        ))?;
+        let mut payload_stmt = tx.prepare(&format!(
+            "INSERT OR REPLACE INTO {name}_payload(rowid,payload) VALUES (?1, ?2)"
+        ))?;
+
+        for point in points {
+            // `point.id as i64` would silently wrap ids above `i64::MAX`.
+            let rowid = i64::try_from(point.id).map_err(invalid_input)?;
+            let payload = serde_json::to_string(&point.payload).map_err(invalid_input)?;
+            vector_stmt.execute(params![rowid, vector_to_blob(&point.vector)])?;
+            payload_stmt.execute(params![rowid, payload])?;
+            success_id.push(point.id);
+        }
+    }
+    tx.commit()?;
+    Ok(success_id)
+}
+
+/// Returns the points of collection `name` whose rowid is in `ids`.
+///
+/// Ids with no stored vector are skipped; a point without a payload row gets `None`.
+///
+/// # Errors
+///
+/// Fails if `name` is invalid or a query or row conversion fails. Statement preparation
+/// used to `unwrap()` and row errors were silently dropped.
+pub fn get_points(
+    conn: &rusqlite::Connection,
+    name: &str,
+    ids: Vec<u64>,
+) -> rusqlite::Result<Vec<Point>> {
+    check_collection_name(name)?;
+    let id_list = join_ids(ids);
+    let mut payloads = load_payloads(conn, name, &id_list)?;
+
+    let mut stmt = conn.prepare(&format!(
+        "SELECT rowid,vector_to_raw(point) FROM {name} WHERE rowid in ({id_list});"
+    ))?;
+    let rows = stmt.query_map([], |row| {
+        let id: u64 = row.get(0)?;
+        let raw: Vec<u8> = row.get(1)?;
+        Ok((id, blob_to_vector(&raw)))
+    })?;
+
+    let mut points = Vec::new();
+    for row in rows {
+        let (id, vector) = row?;
+        let payload = payloads.remove(&id).flatten();
+        points.push(Point { id, vector, payload });
+    }
+    Ok(points)
+}
+
+/// Returns the point with rowid `id` from collection `name`.
+///
+/// # Errors
+///
+/// Fails with `QueryReturnedNoRows` if the vector or payload row is missing, and with
+/// other errors if `name` is invalid or a query fails.
+pub fn get_point(conn: &Connection, name: &str, id: u64) -> rusqlite::Result<Point> {
+    check_collection_name(name)?;
+
+    let vector = conn.query_row(
+        &format!("SELECT rowid,vector_to_raw(point) FROM {name} WHERE rowid = ?1;"),
+        params![id],
+        |row| {
+            let raw: Vec<u8> = row.get(1)?;
+            Ok(blob_to_vector(&raw))
+        },
+    )?;
+
+    let payload = conn.query_row(
+        &format!("SELECT rowid,payload FROM {name}_payload WHERE rowid = ?1;"),
+        params![id],
+        |row| {
+            let text: Option<String> = row.get(1)?;
+            Ok(parse_payload(text.as_deref()))
+        },
+    )?;
+
+    Ok(Point {
+        id,
+        vector,
+        payload,
+    })
+}
+
+/// Builds the six-city fixture shared by the point tests.
+#[cfg(test)]
+fn sample_points() -> Vec<Point> {
+    use serde_json::json;
+
+    let point = |id: u64, vector: [f32; 4], city: &str| Point {
+        id,
+        vector: vector.to_vec(),
+        payload: json!({ "city": city }).as_object().cloned(),
+    };
+    vec![
+        point(1, [0.05, 0.61, 0.76, 0.74], "Berlin"),
+        point(2, [0.19, 0.81, 0.75, 0.11], "London"),
+        point(3, [0.36, 0.55, 0.47, 0.94], "Moscow"),
+        point(4, [0.18, 0.01, 0.85, 0.80], "New York"),
+        point(5, [0.24, 0.18, 0.22, 0.44], "Beijing"),
+        point(6, [0.35, 0.08, 0.11, 0.44], "Mumbai"),
+    ]
+}
+
+#[test]
+fn test_points_base() {
+    init();
+    let conn = rusqlite::Connection::open_in_memory().unwrap();
+    create_collections(&conn, "test_vss", 4).unwrap();
+    let points = sample_points();
+    let r = add_point(&conn, "test_vss", &points).unwrap();
+    assert_eq!(r, vec![1, 2, 3, 4, 5, 6]);
+
+    let mut r = get_points(&conn, "test_vss", vec![1, 2, 3]).unwrap();
+    assert_eq!(r.len(), 3);
+    r.sort_by(|a, b| a.id.cmp(&b.id));
+    assert_eq!(r[0].payload, points[0].payload);
+    assert_eq!(r[1].payload, points[1].payload);
+    assert_eq!(r[2].payload, points[2].payload);
+
+    let r = get_point(&conn, "test_vss", 4).unwrap();
+    assert_eq!(r.payload, points[3].payload);
+}
+
+/// Returns up to `limit` points of collection `name` nearest to `vector`, closest first.
+///
+/// Results stay in the `ORDER BY distance` order of the query. They used to be funnelled
+/// through a `HashMap`, which returned them in arbitrary order.
+///
+/// # Errors
+///
+/// Fails if `name` is invalid or a query or row conversion fails.
+pub fn search_points(
+    conn: &Connection,
+    name: &str,
+    vector: &[f32],
+    limit: usize,
+) -> rusqlite::Result<Vec<ScoredPoint>> {
+    check_collection_name(name)?;
+
+    let mut stmt = conn.prepare(&format!(
+        "SELECT rowid,vector_to_raw(point),distance FROM {name} \
+         WHERE vss_search(point,vector_from_raw(?1)) ORDER BY distance LIMIT ?2;"
+    ))?;
+    let query = vector_to_blob(vector);
+    let rows = stmt.query_map(params![query, limit], |row| {
+        let raw: Vec<u8> = row.get(1)?;
+        Ok(ScoredPoint {
+            id: row.get(0)?,
+            vector: blob_to_vector(&raw),
+            payload: None,
+            score: row.get(2)?,
+        })
+    })?;
+    let mut hits = rows.collect::<rusqlite::Result<Vec<_>>>()?;
+
+    let id_list = join_ids(hits.iter().map(|hit| hit.id));
+    let mut payloads = load_payloads(conn, name, &id_list)?;
+    for hit in &mut hits {
+        hit.payload = payloads.remove(&hit.id).flatten();
+    }
+    Ok(hits)
+}
+
+#[test]
+fn test_points_search() {
+    init();
+    let conn = rusqlite::Connection::open_in_memory().unwrap();
+    create_collections(&conn, "test_vss", 4).unwrap();
+    let points = sample_points();
+    let r = add_point(&conn, "test_vss", &points).unwrap();
+    assert_eq!(r, vec![1, 2, 3, 4, 5, 6]);
+
+    let q = vec![0.2, 0.1, 0.9, 0.7];
+    let r = search_points(&conn, "test_vss", &q, 2).unwrap();
+    assert_eq!(r.len(), 2);
+    assert_eq!(r[0].id, 4);
+    assert_eq!(r[1].id, 1);
+}
+
+/// Deletes the points with the given `ids` (vector and payload rows) from collection `name`.
+///
+/// # Errors
+///
+/// Fails if `name` is invalid or a delete fails; both deletes are then rolled back.
+pub fn delete_points(conn: &Connection, name: &str, ids: Vec<u64>) -> rusqlite::Result<()> {
+    check_collection_name(name)?;
+    let id_list = join_ids(ids);
+    let sql = format!(
+        r#"
+        DELETE FROM {name} WHERE rowid in ({id_list});
+        DELETE FROM {name}_payload WHERE rowid in ({id_list});
+        "#
+    );
+    run_in_transaction(conn, &sql)
+}
+
+#[test]
+fn test_points_delete() {
+    init();
+    let conn = rusqlite::Connection::open_in_memory().unwrap();
+    create_collections(&conn, "test_vss", 4).unwrap();
+    let points = sample_points();
+    let r = add_point(&conn, "test_vss", &points).unwrap();
+    assert_eq!(r, vec![1, 2, 3, 4, 5, 6]);
+
+    delete_points(&conn, "test_vss", vec![1, 2, 3, 4]).unwrap();
+
+    let r = get_points(&conn, "test_vss", vec![1, 2, 3, 4]).unwrap();
+    assert_eq!(r.len(), 0);
+}
+
+/// Drops collection `name` and its `{name}_payload` table, if they exist.
+///
+/// # Errors
+///
+/// Fails if `name` is invalid or a `DROP TABLE` fails; the batch is then rolled back.
+pub fn delete_collection(conn: &Connection, name: &str) -> rusqlite::Result<()> {
+    check_collection_name(name)?;
+    let sql = format!(
+        r#"
+        DROP TABLE IF EXISTS {name};
+        DROP TABLE IF EXISTS {name}_payload;
+        "#
+    );
+    run_in_transaction(conn, &sql)
+}