Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support aggregate function for vector. #463

Merged
merged 6 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/base/src/vector/vecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ impl<'a> Vecf32Borrowed<'a> {
pub fn slice(&self) -> &[F32] {
self.0
}
pub fn l2_norm(&self) -> F32 {
dot(self.slice(), self.slice()).sqrt()
}
}

impl<'a> VectorBorrowed for Vecf32Borrowed<'a> {
Expand Down
292 changes: 292 additions & 0 deletions src/datatype/functions_vecf32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
#![allow(unused_lifetimes)]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#![allow(clippy::extra_unused_lifetimes)]
use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output};
use crate::error::*;
use base::scalar::*;
use base::vector::*;
use pgrx::pg_sys::Datum;
use pgrx::pg_sys::Oid;
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::{FromDatum, IntoDatum};
use std::alloc::Layout;
use std::ffi::{CStr, CString};
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;

#[repr(C, align(8))]
pub struct AccumulateStateHeader {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'll be better if it's more generic so that we use reuse it for more usage.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think only AccumulateState for vecf32 and AccumulateState for vecf16 layout can be same? Just adding a generic type seems won't be so helpful?

varlena: u32,
dims: u16,
count: u64,
phantom: [F32; 0],
}

impl AccumulateStateHeader {
fn varlena(size: usize) -> u32 {
(size << 2) as u32
}
fn layout(len: usize) -> Layout {
u16::try_from(len).expect("Vector is too large.");
let layout_alpha = Layout::new::<AccumulateStateHeader>();
let layout_beta = Layout::array::<F32>(len).unwrap();
let layout = layout_alpha.extend(layout_beta).unwrap().0;
layout.pad_to_align()
}
pub fn dims(&self) -> usize {
self.dims as usize
}
pub fn count(&self) -> u64 {
self.count
}
pub fn slice(&self) -> &[F32] {
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize) }
}
pub fn slice_mut(&mut self) -> &mut [F32] {
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.dims as usize) }
}
}

pub enum AccumulateState<'a> {
Owned(NonNull<AccumulateStateHeader>),
Borrowed(&'a mut AccumulateStateHeader),
}

impl<'a> AccumulateState<'a> {
unsafe fn new(p: NonNull<AccumulateStateHeader>) -> Self {
// datum maybe toasted, try to detoast it
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap()
};
if p != q {
AccumulateState::Owned(q)
} else {
unsafe { AccumulateState::Borrowed(&mut *p.as_ptr()) }
}
}

pub fn new_with_slice(count: u64, slice: &[F32]) -> Self {
let dims = slice.len();
let layout = AccumulateStateHeader::layout(dims);
unsafe {
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut AccumulateStateHeader;
std::ptr::addr_of_mut!((*ptr).varlena)
.write(AccumulateStateHeader::varlena(layout.size()));
std::ptr::addr_of_mut!((*ptr).dims).write(dims as u16);
std::ptr::addr_of_mut!((*ptr).count).write(count);
if dims > 0 {
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), dims);
}
AccumulateState::Owned(NonNull::new(ptr).unwrap())
}
}

pub fn into_raw(self) -> *mut AccumulateStateHeader {
let result = match self {
AccumulateState::Owned(p) => p.as_ptr(),
AccumulateState::Borrowed(ref p) => {
*p as *const AccumulateStateHeader as *mut AccumulateStateHeader
}
};
std::mem::forget(self);
result
}
}

impl Deref for AccumulateState<'_> {
type Target = AccumulateStateHeader;

fn deref(&self) -> &Self::Target {
match self {
AccumulateState::Owned(p) => unsafe { p.as_ref() },
AccumulateState::Borrowed(p) => p,
}
}
}

impl DerefMut for AccumulateState<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
AccumulateState::Owned(p) => unsafe { p.as_mut() },
AccumulateState::Borrowed(p) => p,
}
}
}

impl Drop for AccumulateState<'_> {
fn drop(&mut self) {
match self {
AccumulateState::Owned(p) => unsafe {
pgrx::pg_sys::pfree(p.as_ptr().cast());
},
AccumulateState::Borrowed(_) => {}
}
}
}

impl FromDatum for AccumulateState<'_> {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typmod: Oid) -> Option<Self> {
if is_null {
None
} else {
let ptr = NonNull::new(datum.cast_mut_ptr::<AccumulateStateHeader>()).unwrap();
unsafe { Some(AccumulateState::new(ptr)) }
}
}
}

impl IntoDatum for AccumulateState<'_> {
fn into_datum(self) -> Option<Datum> {
Some(Datum::from(self.into_raw() as *mut ()))
}

fn type_oid() -> Oid {
let namespace = pgrx::pg_catalog::PgNamespace::search_namespacename(c"vectors").unwrap();
let namespace = namespace.get().expect("pgvecto.rs is not installed.");
let t = pgrx::pg_catalog::PgType::search_typenamensp(
c"vector_accumulate_state ",
namespace.oid(),
)
.unwrap();
let t = t.get().expect("pg_catalog is broken.");
t.oid()
}
}

unsafe impl SqlTranslatable for AccumulateState<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vector_accumulate_state ")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from(
"vector_accumulate_state ",
))))
}
}

fn parse_accumulate_state(input: &[u8]) -> Result<(u64, Vec<F32>), String> {
use crate::utils::parse::parse_vector;
let hint = "Invalid input format for accumulatestate, using \'bigint, array \' like \'1, [1]\'";
let (count, slice) = input.split_once(|&c| c == b',').ok_or(hint)?;
let count = std::str::from_utf8(count)
.map_err(|e| e.to_string() + "\n" + hint)?
.parse::<u64>()
.map_err(|e| e.to_string() + "\n" + hint)?;
let v = parse_vector(slice, 0, |s| s.parse().ok());
match v {
Err(e) => Err(e.to_string() + "\n" + hint),
Ok(vector) => Ok((count, vector)),
}
}

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_accumulate_state_in(input: &CStr, _oid: Oid, _typmod: i32) -> AccumulateState<'_> {
// parse one bigint and a vector of f32, split with a comma
let res = parse_accumulate_state(input.to_bytes());
match res {
Err(e) => {
bad_literal(&e.to_string());
}
Ok((count, vector)) => AccumulateState::new_with_slice(count, &vector),
}
}

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_accumulate_state_out(state: AccumulateState<'_>) -> CString {
let mut buffer = String::new();
buffer.push_str(format!("{}, ", state.count()).as_str());
buffer.push('[');
if let Some(&x) = state.slice().first() {
buffer.push_str(format!("{}", x).as_str());
}
for &x in state.slice().iter().skip(1) {
buffer.push_str(format!(", {}", x).as_str());
}
buffer.push(']');
CString::new(buffer).unwrap()
}

/// accumulate intermediate state for vector average
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_accum<'a>(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meet unused-lifetimes warning here. I don't configure out the reason here. As far as I know , this warning only occurs with #[pgrx::pg_extern].
The 235th line is the same.
cc @usamoi

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The macro expands to

pub unsafe extern "C" fn _vectors_bvecf32_send_wrapper<'a>(
    _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
) -> ::pgrx::pg_sys::Datum

The lifetime parameter is redundant here. You could use allow to suppress this warning before pgrx fixes it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just discovered this problem too. After add #![allow(unused_lifetimes)] #![allow(clippy::extra_unused_lifetimes)] in file header, it works. Thanks!

mut state: AccumulateState<'a>,
value: Vecf32Input<'_>,
) -> AccumulateState<'a> {
let count = state.count();
match count {
// if the state is empty, copy the input vector
0 => AccumulateState::new_with_slice(1, value.iter().as_slice()),
_ => {
let dims = state.dims();
let value_dims = value.dims();
check_matched_dims(dims, value_dims);
let sum = state.slice_mut();
// accumulate the input vector
for (x, y) in sum.iter_mut().zip(value.iter()) {
*x += *y;
}
// increase the count
state.count += 1;
state
}
}
}

/// combine two intermediate states for vector average
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_combine<'a>(
mut state1: AccumulateState<'a>,
state2: AccumulateState<'a>,
) -> AccumulateState<'a> {
let count1 = state1.count();
let count2 = state2.count();
if count1 == 0 {
state2
} else if count2 == 0 {
state1
} else {
let dims1 = state1.dims();
let dims2 = state2.dims();
check_matched_dims(dims1, dims2);
state1.count += count2;
let sum1 = state1.slice_mut();
let sum2 = state2.slice();
for (x, y) in sum1.iter_mut().zip(sum2.iter()) {
*x += *y;
}
state1
}
}

/// finalize the intermediate state for vector average
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_final(state: AccumulateState<'_>) -> Option<Vecf32Output> {
let count = state.count();
if count == 0 {
// return NULL if all inputs are NULL
return None;
}
let sum = state
.slice()
.iter()
.map(|x| *x / F32(count as f32))
.collect::<Vec<_>>();
Some(Vecf32Output::new(
Vecf32Borrowed::new_checked(&sum).unwrap(),
))
}

/// Get the dimensions of a vector.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_dims(vector: Vecf32Input<'_>) -> i32 {
vector.dims() as i32
}

/// Calculate the l2 norm of a vector.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_norm(vector: Vecf32Input<'_>) -> f32 {
vector.for_borrow().l2_norm().to_f32()
}
1 change: 1 addition & 0 deletions src/datatype/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod binary_veci8;
pub mod casts;
pub mod functions_bvecf32;
pub mod functions_svecf32;
pub mod functions_vecf32;
pub mod functions_veci8;
pub mod memory_bvecf32;
pub mod memory_svecf32;
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//!
//! Provides an easy-to-use extension for vector similarity search.
#![feature(alloc_error_hook)]
#![feature(slice_split_once)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::single_match)]
#![allow(clippy::too_many_arguments)]
Expand Down
1 change: 1 addition & 0 deletions src/sql/bootstrap.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ CREATE TYPE svector;
CREATE TYPE bvector;
CREATE TYPE veci8;
CREATE TYPE vector_index_stat;
CREATE TYPE vector_accumulate_state;

-- bootstrap end
32 changes: 32 additions & 0 deletions src/sql/finalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ CREATE TYPE vector_index_stat AS (
idx_options TEXT
);

CREATE TYPE vector_accumulate_state (
INPUT = _vectors_accumulate_state_in,
OUTPUT = _vectors_accumulate_state_out,
STORAGE = EXTERNAL,
INTERNALLENGTH = VARIABLE,
ALIGNMENT = double
);

-- List of operators

CREATE OPERATOR + (
Expand Down Expand Up @@ -593,6 +601,30 @@ $$;
CREATE FUNCTION alter_vector_index("index" OID, "key" TEXT, "value" TEXT) RETURNS void
STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_alter_vector_index_wrapper';

CREATE FUNCTION vector_dims("v" vector) RETURNS INT
STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_vector_dims_wrapper';

CREATE FUNCTION vector_norm("v" vector) RETURNS real
STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_vector_norm_wrapper';

-- List of aggregates

CREATE AGGREGATE avg(vector) (
SFUNC = _vectors_vector_accum,
STYPE = vector_accumulate_state,
COMBINEFUNC = _vectors_vector_combine,
FINALFUNC = _vectors_vector_final,
INITCOND = '0, []',
PARALLEL = SAFE
);

CREATE AGGREGATE sum(vector) (
SFUNC = _vectors_vecf32_operator_add,
STYPE = vector,
COMBINEFUNC = _vectors_vecf32_operator_add,
PARALLEL = SAFE
);

-- List of casts

CREATE CAST (real[] AS vector)
Expand Down
Loading
Loading