Skip to content

Commit 48ac1be

Browse files
my-vegetable-has-explodedJinweiOS
authored and committed
feat: support aggregate function for vector. (tensorchord#463)
* implement vector avg. Signed-off-by: my-vegetable-has-exploded <[email protected]> * Implement sum, vector_dims(), vector_norm() Signed-off-by: my-vegetable-has-exploded <[email protected]> * update AccumulateState Signed-off-by: my-vegetable-has-exploded <[email protected]> * fix norm. Signed-off-by: my-vegetable-has-exploded <[email protected]> * add l2_norm for VectorBorrowed. Signed-off-by: my-vegetable-has-exploded <[email protected]> --------- Signed-off-by: my-vegetable-has-exploded <[email protected]> Signed-off-by: jinweios <[email protected]>
1 parent 3848413 commit 48ac1be

File tree

7 files changed

+434
-0
lines changed

7 files changed

+434
-0
lines changed

crates/base/src/vector/vecf32.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ impl<'a> Vecf32Borrowed<'a> {
8282
pub fn slice(&self) -> &[F32] {
8383
self.0
8484
}
85+
pub fn l2_norm(&self) -> F32 {
86+
dot(self.slice(), self.slice()).sqrt()
87+
}
8588
}
8689

8790
impl<'a> VectorBorrowed for Vecf32Borrowed<'a> {

src/datatype/functions_vecf32.rs

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
#![allow(unused_lifetimes)]
2+
#![allow(clippy::extra_unused_lifetimes)]
3+
use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output};
4+
use crate::error::*;
5+
use base::scalar::*;
6+
use base::vector::*;
7+
use pgrx::pg_sys::Datum;
8+
use pgrx::pg_sys::Oid;
9+
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
10+
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
11+
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
12+
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
13+
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
14+
use pgrx::{FromDatum, IntoDatum};
15+
use std::alloc::Layout;
16+
use std::ffi::{CStr, CString};
17+
use std::ops::{Deref, DerefMut};
18+
use std::ptr::NonNull;
19+
20+
/// Header of the variable-length intermediate state used by the vector aggregates.
///
/// The `dims` accumulated `F32` values are read from the memory that starts at `phantom`,
/// so the field order and `repr(C)` layout matter.
#[repr(C, align(8))]
21+
pub struct AccumulateStateHeader {
22+
// Length word; `new_with_slice` stores the allocation size shifted left by two bits.
varlena: u32,
23+
// Number of `F32` elements that follow the header.
dims: u16,
24+
// Number of input vectors folded into the running sum so far.
count: u64,
25+
// Zero-length marker whose address is where the element data begins.
phantom: [F32; 0],
26+
}
27+
28+
impl AccumulateStateHeader {
29+
fn varlena(size: usize) -> u32 {
30+
(size << 2) as u32
31+
}
32+
fn layout(len: usize) -> Layout {
33+
u16::try_from(len).expect("Vector is too large.");
34+
let layout_alpha = Layout::new::<AccumulateStateHeader>();
35+
let layout_beta = Layout::array::<F32>(len).unwrap();
36+
let layout = layout_alpha.extend(layout_beta).unwrap().0;
37+
layout.pad_to_align()
38+
}
39+
pub fn dims(&self) -> usize {
40+
self.dims as usize
41+
}
42+
pub fn count(&self) -> u64 {
43+
self.count
44+
}
45+
pub fn slice(&self) -> &[F32] {
46+
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize) }
47+
}
48+
pub fn slice_mut(&mut self) -> &mut [F32] {
49+
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.dims as usize) }
50+
}
51+
}
52+
53+
/// Handle to an aggregate state that either owns its allocation or borrows an existing one.
pub enum AccumulateState<'a> {
54+
// Pointer to a state that this handle releases with `pfree` when dropped.
Owned(NonNull<AccumulateStateHeader>),
55+
// State referenced in place; dropping the handle leaves the memory untouched.
Borrowed(&'a mut AccumulateStateHeader),
56+
}
57+
58+
impl<'a> AccumulateState<'a> {
59+
unsafe fn new(p: NonNull<AccumulateStateHeader>) -> Self {
60+
// datum maybe toasted, try to detoast it
61+
let q = unsafe {
62+
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap()
63+
};
64+
if p != q {
65+
AccumulateState::Owned(q)
66+
} else {
67+
unsafe { AccumulateState::Borrowed(&mut *p.as_ptr()) }
68+
}
69+
}
70+
71+
pub fn new_with_slice(count: u64, slice: &[F32]) -> Self {
72+
let dims = slice.len();
73+
let layout = AccumulateStateHeader::layout(dims);
74+
unsafe {
75+
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut AccumulateStateHeader;
76+
std::ptr::addr_of_mut!((*ptr).varlena)
77+
.write(AccumulateStateHeader::varlena(layout.size()));
78+
std::ptr::addr_of_mut!((*ptr).dims).write(dims as u16);
79+
std::ptr::addr_of_mut!((*ptr).count).write(count);
80+
if dims > 0 {
81+
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), dims);
82+
}
83+
AccumulateState::Owned(NonNull::new(ptr).unwrap())
84+
}
85+
}
86+
87+
pub fn into_raw(self) -> *mut AccumulateStateHeader {
88+
let result = match self {
89+
AccumulateState::Owned(p) => p.as_ptr(),
90+
AccumulateState::Borrowed(ref p) => {
91+
*p as *const AccumulateStateHeader as *mut AccumulateStateHeader
92+
}
93+
};
94+
std::mem::forget(self);
95+
result
96+
}
97+
}
98+
99+
impl Deref for AccumulateState<'_> {
100+
type Target = AccumulateStateHeader;
101+
102+
fn deref(&self) -> &Self::Target {
103+
match self {
104+
AccumulateState::Owned(p) => unsafe { p.as_ref() },
105+
AccumulateState::Borrowed(p) => p,
106+
}
107+
}
108+
}
109+
110+
impl DerefMut for AccumulateState<'_> {
111+
fn deref_mut(&mut self) -> &mut Self::Target {
112+
match self {
113+
AccumulateState::Owned(p) => unsafe { p.as_mut() },
114+
AccumulateState::Borrowed(p) => p,
115+
}
116+
}
117+
}
118+
119+
impl Drop for AccumulateState<'_> {
120+
fn drop(&mut self) {
121+
match self {
122+
AccumulateState::Owned(p) => unsafe {
123+
pgrx::pg_sys::pfree(p.as_ptr().cast());
124+
},
125+
AccumulateState::Borrowed(_) => {}
126+
}
127+
}
128+
}
129+
130+
impl FromDatum for AccumulateState<'_> {
131+
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typmod: Oid) -> Option<Self> {
132+
if is_null {
133+
None
134+
} else {
135+
let ptr = NonNull::new(datum.cast_mut_ptr::<AccumulateStateHeader>()).unwrap();
136+
unsafe { Some(AccumulateState::new(ptr)) }
137+
}
138+
}
139+
}
140+
141+
impl IntoDatum for AccumulateState<'_> {
142+
fn into_datum(self) -> Option<Datum> {
143+
Some(Datum::from(self.into_raw() as *mut ()))
144+
}
145+
146+
fn type_oid() -> Oid {
147+
let namespace = pgrx::pg_catalog::PgNamespace::search_namespacename(c"vectors").unwrap();
148+
let namespace = namespace.get().expect("pgvecto.rs is not installed.");
149+
let t = pgrx::pg_catalog::PgType::search_typenamensp(
150+
c"vector_accumulate_state ",
151+
namespace.oid(),
152+
)
153+
.unwrap();
154+
let t = t.get().expect("pg_catalog is broken.");
155+
t.oid()
156+
}
157+
}
158+
159+
unsafe impl SqlTranslatable for AccumulateState<'_> {
160+
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
161+
Ok(SqlMapping::As(String::from("vector_accumulate_state ")))
162+
}
163+
fn return_sql() -> Result<Returns, ReturnsError> {
164+
Ok(Returns::One(SqlMapping::As(String::from(
165+
"vector_accumulate_state ",
166+
))))
167+
}
168+
}
169+
170+
fn parse_accumulate_state(input: &[u8]) -> Result<(u64, Vec<F32>), String> {
171+
use crate::utils::parse::parse_vector;
172+
let hint = "Invalid input format for accumulatestate, using \'bigint, array \' like \'1, [1]\'";
173+
let (count, slice) = input.split_once(|&c| c == b',').ok_or(hint)?;
174+
let count = std::str::from_utf8(count)
175+
.map_err(|e| e.to_string() + "\n" + hint)?
176+
.parse::<u64>()
177+
.map_err(|e| e.to_string() + "\n" + hint)?;
178+
let v = parse_vector(slice, 0, |s| s.parse().ok());
179+
match v {
180+
Err(e) => Err(e.to_string() + "\n" + hint),
181+
Ok(vector) => Ok((count, vector)),
182+
}
183+
}
184+
185+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
186+
fn _vectors_accumulate_state_in(input: &CStr, _oid: Oid, _typmod: i32) -> AccumulateState<'_> {
187+
// parse one bigint and a vector of f32, split with a comma
188+
let res = parse_accumulate_state(input.to_bytes());
189+
match res {
190+
Err(e) => {
191+
bad_literal(&e.to_string());
192+
}
193+
Ok((count, vector)) => AccumulateState::new_with_slice(count, &vector),
194+
}
195+
}
196+
197+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
198+
fn _vectors_accumulate_state_out(state: AccumulateState<'_>) -> CString {
199+
let mut buffer = String::new();
200+
buffer.push_str(format!("{}, ", state.count()).as_str());
201+
buffer.push('[');
202+
if let Some(&x) = state.slice().first() {
203+
buffer.push_str(format!("{}", x).as_str());
204+
}
205+
for &x in state.slice().iter().skip(1) {
206+
buffer.push_str(format!(", {}", x).as_str());
207+
}
208+
buffer.push(']');
209+
CString::new(buffer).unwrap()
210+
}
211+
212+
/// accumulate intermediate state for vector average
213+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
214+
fn _vectors_vector_accum<'a>(
215+
mut state: AccumulateState<'a>,
216+
value: Vecf32Input<'_>,
217+
) -> AccumulateState<'a> {
218+
let count = state.count();
219+
match count {
220+
// if the state is empty, copy the input vector
221+
0 => AccumulateState::new_with_slice(1, value.iter().as_slice()),
222+
_ => {
223+
let dims = state.dims();
224+
let value_dims = value.dims();
225+
check_matched_dims(dims, value_dims);
226+
let sum = state.slice_mut();
227+
// accumulate the input vector
228+
for (x, y) in sum.iter_mut().zip(value.iter()) {
229+
*x += *y;
230+
}
231+
// increase the count
232+
state.count += 1;
233+
state
234+
}
235+
}
236+
}
237+
238+
/// combine two intermediate states for vector average
239+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
240+
fn _vectors_vector_combine<'a>(
241+
mut state1: AccumulateState<'a>,
242+
state2: AccumulateState<'a>,
243+
) -> AccumulateState<'a> {
244+
let count1 = state1.count();
245+
let count2 = state2.count();
246+
if count1 == 0 {
247+
state2
248+
} else if count2 == 0 {
249+
state1
250+
} else {
251+
let dims1 = state1.dims();
252+
let dims2 = state2.dims();
253+
check_matched_dims(dims1, dims2);
254+
state1.count += count2;
255+
let sum1 = state1.slice_mut();
256+
let sum2 = state2.slice();
257+
for (x, y) in sum1.iter_mut().zip(sum2.iter()) {
258+
*x += *y;
259+
}
260+
state1
261+
}
262+
}
263+
264+
/// finalize the intermediate state for vector average
265+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
266+
fn _vectors_vector_final(state: AccumulateState<'_>) -> Option<Vecf32Output> {
267+
let count = state.count();
268+
if count == 0 {
269+
// return NULL if all inputs are NULL
270+
return None;
271+
}
272+
let sum = state
273+
.slice()
274+
.iter()
275+
.map(|x| *x / F32(count as f32))
276+
.collect::<Vec<_>>();
277+
Some(Vecf32Output::new(
278+
Vecf32Borrowed::new_checked(&sum).unwrap(),
279+
))
280+
}
281+
282+
/// Get the dimensions of a vector.
283+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
284+
fn _vectors_vector_dims(vector: Vecf32Input<'_>) -> i32 {
285+
vector.dims() as i32
286+
}
287+
288+
/// Calculate the l2 norm of a vector.
289+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
290+
fn _vectors_vector_norm(vector: Vecf32Input<'_>) -> f32 {
291+
vector.for_borrow().l2_norm().to_f32()
292+
}

src/datatype/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub mod binary_veci8;
77
pub mod casts;
88
pub mod functions_bvecf32;
99
pub mod functions_svecf32;
10+
pub mod functions_vecf32;
1011
pub mod functions_veci8;
1112
pub mod memory_bvecf32;
1213
pub mod memory_svecf32;

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//!
33
//! Provides an easy-to-use extension for vector similarity search.
44
#![feature(alloc_error_hook)]
5+
#![feature(slice_split_once)]
56
#![allow(clippy::needless_range_loop)]
67
#![allow(clippy::single_match)]
78
#![allow(clippy::too_many_arguments)]

src/sql/bootstrap.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ CREATE TYPE svector;
88
CREATE TYPE bvector;
99
CREATE TYPE veci8;
1010
CREATE TYPE vector_index_stat;
11+
CREATE TYPE vector_accumulate_state;
1112

1213
-- bootstrap end

src/sql/finalize.sql

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ CREATE TYPE vector_index_stat AS (
7878
idx_options TEXT
7979
);
8080

81+
-- Intermediate state type of the avg(vector) aggregate. Its text form is handled by
-- _vectors_accumulate_state_in and _vectors_accumulate_state_out.
CREATE TYPE vector_accumulate_state (
82+
INPUT = _vectors_accumulate_state_in,
83+
OUTPUT = _vectors_accumulate_state_out,
84+
STORAGE = EXTERNAL,
85+
INTERNALLENGTH = VARIABLE,
86+
-- Matches the align(8) declared on the Rust header struct.
ALIGNMENT = double
87+
);
88+
8189
-- List of operators
8290

8391
CREATE OPERATOR + (
@@ -593,6 +601,30 @@ $$;
593601
CREATE FUNCTION alter_vector_index("index" OID, "key" TEXT, "value" TEXT) RETURNS void
594602
STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_alter_vector_index_wrapper';
595603

604+
-- Returns the number of dimensions of a vector.
-- Declared IMMUTABLE to match the `immutable` attribute on the Rust implementation;
-- without it the function defaults to VOLATILE.
CREATE FUNCTION vector_dims("v" vector) RETURNS INT
IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_vector_dims_wrapper';
606+
607+
-- Returns the L2 norm of a vector.
-- Declared IMMUTABLE to match the `immutable` attribute on the Rust implementation;
-- without it the function defaults to VOLATILE.
CREATE FUNCTION vector_norm("v" vector) RETURNS real
IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_vector_norm_wrapper';
609+
610+
-- List of aggregates
611+
612+
-- Element-wise average of vectors. The state holds a count and a running sum;
-- the final function divides the sum by the count.
CREATE AGGREGATE avg(vector) (
613+
SFUNC = _vectors_vector_accum,
614+
STYPE = vector_accumulate_state,
615+
COMBINEFUNC = _vectors_vector_combine,
616+
FINALFUNC = _vectors_vector_final,
617+
-- Empty state: count 0 and no elements; the first input vector replaces it.
INITCOND = '0, []',
618+
PARALLEL = SAFE
619+
);
620+
621+
-- Element-wise sum of vectors; the vector addition function serves as both the
-- transition and the combine function.
CREATE AGGREGATE sum(vector) (
622+
SFUNC = _vectors_vecf32_operator_add,
623+
STYPE = vector,
624+
COMBINEFUNC = _vectors_vecf32_operator_add,
625+
PARALLEL = SAFE
626+
);
627+
596628
-- List of casts
597629

598630
CREATE CAST (real[] AS vector)

0 commit comments

Comments
 (0)