-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 17f659c
Showing
5 changed files
with
841 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
/target |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[package] | ||
name = "rusqlite-vss" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
# rusqlite = { version = "0.31.0", features=["bundled"] } | ||
rusqlite = { version = "0.29.0", features = ["bundled"] } | ||
sqlite-vss = { version = "0.1.2", features = ["download-libs"] } | ||
tokio = { version = "1.37.0", features = ["full"] } | ||
axum = "0.7.5" | ||
anyhow = "1.0.86" | ||
serde = { version = "1", features = ["derive"] } | ||
serde_json = "1.0.117" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
use std::sync::Arc; | ||
|
||
use axum::{ | ||
routing::{delete, get, post, put}, | ||
Router, | ||
}; | ||
use tokio::sync::Mutex; | ||
|
||
pub mod service; | ||
pub mod store; | ||
|
||
#[tokio::main] | ||
async fn main() -> anyhow::Result<()> { | ||
store::init(); | ||
|
||
let addr = std::env::var("LISTEN_ADDR").unwrap_or("0.0.0.0:3000".to_string()); | ||
|
||
let db = store::open("store.sqlite")?; | ||
|
||
let app = Router::new() | ||
.route("/collections/:name", put(service::create_collections)) | ||
.route("/collections/:name", get(service::get_collections_info)) | ||
.route("/collections/:name", delete(service::delete_collection)) | ||
.route("/collections/:name/points", put(service::add_points)) | ||
.route( | ||
"/collections/:name/points/delete", | ||
post(service::delete_points), | ||
) | ||
.route( | ||
"/collections/:name/points/:point_id", | ||
get(service::get_point), | ||
) | ||
.route("/collections/:name/points", post(service::get_points)) | ||
.with_state(Arc::new(Mutex::new(db))); | ||
|
||
let listener = tokio::net::TcpListener::bind(addr).await?; | ||
axum::serve(listener, app).await?; | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,328 @@ | ||
use std::sync::Arc; | ||
|
||
use axum::{ | ||
extract::{Path, State}, | ||
response::IntoResponse, | ||
Json, | ||
}; | ||
|
||
use rusqlite::OptionalExtension; | ||
use tokio::sync::Mutex; | ||
|
||
use crate::store; | ||
|
||
#[derive(Debug, serde::Serialize)] | ||
pub struct APIResult<T> { | ||
pub result: T, | ||
pub status: Option<String>, | ||
pub error: Option<String>, | ||
} | ||
|
||
#[derive(Debug, serde::Deserialize)] | ||
pub struct CreateConllections { | ||
pub vectors: CreateConllectionsVectors, | ||
} | ||
|
||
#[derive(Debug, serde::Deserialize)] | ||
pub struct CreateConllectionsVectors { | ||
pub size: usize, | ||
} | ||
|
||
pub type CreateConllectionsResult = APIResult<bool>; | ||
|
||
pub async fn create_collections( | ||
Path(name): Path<String>, | ||
State(db): State<Arc<Mutex<rusqlite::Connection>>>, | ||
Json(create_conllections): Json<CreateConllections>, | ||
) -> impl IntoResponse { | ||
let conn = db.lock().await; | ||
if let Err(e) = store::create_collections(&conn, &name, create_conllections.vectors.size) { | ||
return ( | ||
axum::http::StatusCode::CONFLICT, | ||
Json(CreateConllectionsResult { | ||
result: false, | ||
status: None, | ||
error: Some(e.to_string()), | ||
}), | ||
); | ||
} else { | ||
return ( | ||
axum::http::StatusCode::OK, | ||
Json(CreateConllectionsResult { | ||
result: true, | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
); | ||
} | ||
} | ||
|
||
#[derive(Debug, serde::Serialize)] | ||
pub struct CollectionsInfo { | ||
pub points_count: u64, | ||
} | ||
|
||
pub type GetCollectionsResult = APIResult<CollectionsInfo>; | ||
|
||
pub async fn get_collections_info( | ||
Path(name): Path<String>, | ||
State(db): State<Arc<Mutex<rusqlite::Connection>>>, | ||
) -> impl IntoResponse { | ||
let conn = db.lock().await; | ||
match store::get_collections_info(&conn, &name) { | ||
Ok(info) => ( | ||
axum::http::StatusCode::OK, | ||
Json(GetCollectionsResult { | ||
result: CollectionsInfo { | ||
points_count: info.points_count, | ||
}, | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
), | ||
Err(e) => ( | ||
axum::http::StatusCode::INTERNAL_SERVER_ERROR, | ||
Json(GetCollectionsResult { | ||
result: CollectionsInfo { points_count: 0 }, | ||
status: None, | ||
error: Some(e.to_string()), | ||
}), | ||
), | ||
} | ||
} | ||
|
||
#[derive(Debug, serde::Deserialize)] | ||
pub struct AddPoints { | ||
pub points: Vec<Point>, | ||
} | ||
|
||
#[derive(Debug, serde::Deserialize, serde::Serialize)] | ||
pub struct Point { | ||
pub id: u64, | ||
pub vector: Vec<f32>, | ||
pub payload: Option<serde_json::Map<String, serde_json::Value>>, | ||
} | ||
|
||
pub type AddPointsResult = APIResult<Option<Vec<u64>>>; | ||
|
||
pub async fn add_points( | ||
Path(name): Path<String>, | ||
State(db): State<Arc<Mutex<rusqlite::Connection>>>, | ||
Json(points): Json<AddPoints>, | ||
) -> impl IntoResponse { | ||
{ | ||
let conn = db.lock().await; | ||
match store::add_point(&conn, &name, &points.points) { | ||
Ok(success_id) => ( | ||
axum::http::StatusCode::OK, | ||
Json(AddPointsResult { | ||
result: Some(success_id), | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
), | ||
Err(e) => ( | ||
axum::http::StatusCode::INTERNAL_SERVER_ERROR, | ||
Json(AddPointsResult { | ||
result: None, | ||
status: None, | ||
error: Some(e.to_string()), | ||
}), | ||
), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, serde::Deserialize)] | ||
pub struct GetPoints { | ||
ids: Vec<u64>, | ||
} | ||
|
||
pub type GetPointsResult = APIResult<Option<Vec<Point>>>; | ||
|
||
pub async fn get_points( | ||
Path(name): Path<String>, | ||
State(db): State<Arc<Mutex<rusqlite::Connection>>>, | ||
Json(ids): Json<GetPoints>, | ||
) -> impl IntoResponse { | ||
let r = { | ||
let conn = db.lock().await; | ||
store::get_points(&conn, &name, ids.ids).optional() | ||
}; | ||
|
||
match r { | ||
Ok(Some(points)) => ( | ||
axum::http::StatusCode::OK, | ||
Json(GetPointsResult { | ||
result: Some(points), | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
), | ||
Ok(None) => ( | ||
axum::http::StatusCode::OK, | ||
Json(GetPointsResult { | ||
result: Some(Vec::new()), | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
), | ||
Err(e) => ( | ||
axum::http::StatusCode::INTERNAL_SERVER_ERROR, | ||
Json(GetPointsResult { | ||
result: None, | ||
status: None, | ||
error: Some(e.to_string()), | ||
}), | ||
), | ||
} | ||
} | ||
|
||
pub type GetPointResult = APIResult<Option<Point>>; | ||
|
||
pub async fn get_point( | ||
Path(name): Path<String>, | ||
Path(point_id): Path<u64>, | ||
State(db): State<Arc<Mutex<rusqlite::Connection>>>, | ||
) -> impl IntoResponse { | ||
let conn: tokio::sync::MutexGuard<rusqlite::Connection> = db.lock().await; | ||
let r = store::get_point(&conn, &name, point_id).optional(); | ||
match r { | ||
Ok(Some(point)) => ( | ||
axum::http::StatusCode::OK, | ||
Json(GetPointResult { | ||
result: Some(point), | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
), | ||
Ok(None) => ( | ||
axum::http::StatusCode::NOT_FOUND, | ||
Json(GetPointResult { | ||
result: None, | ||
status: None, | ||
error: Some(format!( | ||
"Not found: Point with id {} does not exists", | ||
point_id | ||
)), | ||
}), | ||
), | ||
Err(e) => ( | ||
axum::http::StatusCode::INTERNAL_SERVER_ERROR, | ||
Json(GetPointResult { | ||
result: None, | ||
status: None, | ||
error: Some(e.to_string()), | ||
}), | ||
), | ||
} | ||
} | ||
|
||
#[derive(Debug, serde::Deserialize)] | ||
pub struct Search { | ||
pub vector: Vec<f32>, | ||
pub limit: usize, | ||
} | ||
|
||
#[derive(Debug, serde::Serialize)] | ||
pub struct ScoredPoint { | ||
pub id: u64, | ||
pub vector: Vec<f32>, | ||
pub payload: Option<serde_json::Map<String, serde_json::Value>>, | ||
pub score: f32, | ||
} | ||
|
||
pub type SearchResult = APIResult<Option<Vec<ScoredPoint>>>; | ||
|
||
pub async fn search_points( | ||
Path(name): Path<String>, | ||
State(db): State<Arc<Mutex<rusqlite::Connection>>>, | ||
Json(search): Json<Search>, | ||
) -> impl IntoResponse { | ||
let conn = db.lock().await; | ||
let r = store::search_points(&conn, &name, search.vector.as_slice(), search.limit).optional(); | ||
match r { | ||
Ok(Some(points)) => ( | ||
axum::http::StatusCode::OK, | ||
Json(SearchResult { | ||
result: Some(points), | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
), | ||
Ok(None) => ( | ||
axum::http::StatusCode::OK, | ||
Json(SearchResult { | ||
result: Some(Vec::new()), | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
), | ||
Err(e) => ( | ||
axum::http::StatusCode::INTERNAL_SERVER_ERROR, | ||
Json(SearchResult { | ||
result: None, | ||
status: None, | ||
error: Some(e.to_string()), | ||
}), | ||
), | ||
} | ||
} | ||
|
||
#[derive(Debug, serde::Deserialize)] | ||
pub struct DeletePoints { | ||
pub points: Vec<u64>, | ||
} | ||
|
||
pub type DeletePointsResult = APIResult<bool>; | ||
|
||
pub async fn delete_points( | ||
Path(name): Path<String>, | ||
State(db): State<Arc<Mutex<rusqlite::Connection>>>, | ||
Json(points): Json<DeletePoints>, | ||
) -> impl IntoResponse { | ||
let conn = db.lock().await; | ||
match store::delete_points(&conn, &name, points.points) { | ||
Ok(_) => ( | ||
axum::http::StatusCode::OK, | ||
Json(DeletePointsResult { | ||
result: true, | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
), | ||
Err(e) => ( | ||
axum::http::StatusCode::INTERNAL_SERVER_ERROR, | ||
Json(DeletePointsResult { | ||
result: false, | ||
status: None, | ||
error: Some(e.to_string()), | ||
}), | ||
), | ||
} | ||
} | ||
|
||
pub async fn delete_collection( | ||
Path(name): Path<String>, | ||
State(db): State<Arc<Mutex<rusqlite::Connection>>>, | ||
) -> impl IntoResponse { | ||
let conn = db.lock().await; | ||
match store::delete_collection(&conn, &name) { | ||
Ok(_) => ( | ||
axum::http::StatusCode::OK, | ||
Json(DeletePointsResult { | ||
result: true, | ||
status: Some("ok".to_string()), | ||
error: None, | ||
}), | ||
), | ||
Err(e) => ( | ||
axum::http::StatusCode::INTERNAL_SERVER_ERROR, | ||
Json(DeletePointsResult { | ||
result: false, | ||
status: None, | ||
error: Some(e.to_string()), | ||
}), | ||
), | ||
} | ||
} |
Oops, something went wrong.