Skip to content

Commit

Permalink
Added PgVector integration
Browse files Browse the repository at this point in the history
Signed-off-by: chandr-andr (Kiselev Aleksandr) <[email protected]>
  • Loading branch information
chandr-andr committed Nov 18, 2024
1 parent 335e591 commit 6c903e5
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 0 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,6 @@ itertools = "0.12.1"
openssl-src = "300.2.2"
openssl-sys = "0.9.102"
pg_interval = { git = "https://github.com/chandr-andr/rust-postgres-interval.git", branch = "psqlpy" }
pgvector = { git = "https://github.com/chandr-andr/pgvector-rust.git", branch = "psqlpy", features = [
"postgres",
] }
13 changes: 13 additions & 0 deletions python/psqlpy/_internal/extra_types.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -774,3 +774,16 @@ class IntervalArray:
### Parameters:
- `inner`: inner value, sequence of timedelta values.
"""

class PgVector:
"""Represent VECTOR in PostgreSQL."""

def __init__(
self: Self,
inner: typing.Sequence[float | int],
) -> None:
"""Create new instance of PgVector.
### Parameters:
- `inner`: inner value, sequence of float or int values.
"""
2 changes: 2 additions & 0 deletions python/psqlpy/extra_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MoneyArray,
NumericArray,
PathArray,
PgVector,
PointArray,
PyBox,
PyCircle,
Expand Down Expand Up @@ -98,4 +99,5 @@
"LsegArray",
"CircleArray",
"IntervalArray",
"PgVector",
]
20 changes: 20 additions & 0 deletions src/extra_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@ use crate::{
},
};

#[pyclass]
#[derive(Clone)]
pub struct PgVector(Vec<f32>);

#[pymethods]
impl PgVector {
#[new]
fn new(vector: Vec<f32>) -> Self {
Self(vector)
}
}

impl PgVector {
#[must_use]
pub fn inner_value(self) -> Vec<f32> {
self.0
}
}

macro_rules! build_python_type {
($st_name:ident, $rust_type:ty) => {
#[pyclass]
Expand Down Expand Up @@ -412,5 +431,6 @@ pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyRes
pymod.add_class::<LsegArray>()?;
pymod.add_class::<CircleArray>()?;
pymod.add_class::<IntervalArray>()?;
pymod.add_class::<PgVector>()?;
Ok(())
}
12 changes: 12 additions & 0 deletions src/value_converter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use crate::{
exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult},
extra_types,
};
use pgvector::Vector as PgVector;
use postgres_array::{array::Array, Dimension};

static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
Expand Down Expand Up @@ -268,6 +269,8 @@ pub enum PythonDTO {
PyLsegArray(Array<PythonDTO>),
PyCircleArray(Array<PythonDTO>),
PyIntervalArray(Array<PythonDTO>),
// PgVector
PyPgVector(Vec<f32>),
}

impl ToPyObject for PythonDTO {
Expand Down Expand Up @@ -594,6 +597,9 @@ impl ToSql for PythonDTO {
PythonDTO::PyIntervalArray(array) => {
array.to_sql(&Type::INTERVAL_ARRAY, out)?;
}
PythonDTO::PyPgVector(vector) => {
<PgVector as ToSql>::to_sql(&PgVector::from(vector.clone()), ty, out)?;
}
}

if return_is_null_true {
Expand Down Expand Up @@ -1139,6 +1145,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
._convert_to_python_dto();
}

if parameter.is_instance_of::<extra_types::PgVector>() {
return Ok(PythonDTO::PyPgVector(
parameter.extract::<extra_types::PgVector>()?.inner_value(),
));
}

if let Ok(id_address) = parameter.extract::<IpAddr>() {
return Ok(PythonDTO::PyIpAddress(id_address));
}
Expand Down

0 comments on commit 6c903e5

Please sign in to comment.