Skip to content

Commit 4f8bc4f

Browse files
update AccumulateState
Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com>
1 parent b8aad70 commit 4f8bc4f

File tree

5 files changed

+270
-111
lines changed

5 files changed

+270
-111
lines changed

src/datatype/functions_vecf32.rs

Lines changed: 251 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,276 @@
1+
#![allow(unused_lifetimes)]
2+
#![allow(clippy::extra_unused_lifetimes)]
13
use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output};
24
use crate::error::*;
5+
use base::operator::{Operator, Vecf32Dot};
36
use base::scalar::*;
47
use base::vector::*;
8+
use pgrx::pg_sys::Datum;
9+
use pgrx::pg_sys::Oid;
10+
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
11+
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
12+
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
13+
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
14+
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
15+
use pgrx::{FromDatum, IntoDatum};
16+
use std::alloc::Layout;
17+
use std::ffi::{CStr, CString};
18+
use std::ops::{Deref, DerefMut};
19+
use std::ptr::NonNull;
20+
21+
#[repr(C, align(8))]
22+
pub struct AccumulateStateHeader {
23+
varlena: u32,
24+
dims: u16,
25+
count: u64,
26+
phantom: [F32; 0],
27+
}
28+
29+
impl AccumulateStateHeader {
30+
fn varlena(size: usize) -> u32 {
31+
(size << 2) as u32
32+
}
33+
fn layout(len: usize) -> Layout {
34+
u16::try_from(len).expect("Vector is too large.");
35+
let layout_alpha = Layout::new::<AccumulateStateHeader>();
36+
let layout_beta = Layout::array::<F32>(len).unwrap();
37+
let layout = layout_alpha.extend(layout_beta).unwrap().0;
38+
layout.pad_to_align()
39+
}
40+
pub fn dims(&self) -> usize {
41+
self.dims as usize
42+
}
43+
pub fn count(&self) -> u64 {
44+
self.count
45+
}
46+
pub fn slice(&self) -> &[F32] {
47+
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize) }
48+
}
49+
pub fn slice_mut(&mut self) -> &mut [F32] {
50+
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.dims as usize) }
51+
}
52+
}
53+
54+
pub enum AccumulateState<'a> {
55+
Owned(NonNull<AccumulateStateHeader>),
56+
Borrowed(&'a mut AccumulateStateHeader),
57+
}
58+
59+
impl<'a> AccumulateState<'a> {
60+
unsafe fn new(p: NonNull<AccumulateStateHeader>) -> Self {
61+
// datum maybe toasted, try to detoast it
62+
let q = unsafe {
63+
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap()
64+
};
65+
if p != q {
66+
AccumulateState::Owned(q)
67+
} else {
68+
unsafe { AccumulateState::Borrowed(&mut *p.as_ptr()) }
69+
}
70+
}
71+
72+
pub fn new_with_slice(count: u64, slice: &[F32]) -> Self {
73+
let dims = slice.len();
74+
let layout = AccumulateStateHeader::layout(dims);
75+
unsafe {
76+
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut AccumulateStateHeader;
77+
std::ptr::addr_of_mut!((*ptr).varlena)
78+
.write(AccumulateStateHeader::varlena(layout.size()));
79+
std::ptr::addr_of_mut!((*ptr).dims).write(dims as u16);
80+
std::ptr::addr_of_mut!((*ptr).count).write(count);
81+
if dims > 0 {
82+
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), dims);
83+
}
84+
AccumulateState::Owned(NonNull::new(ptr).unwrap())
85+
}
86+
}
87+
88+
pub fn into_raw(self) -> *mut AccumulateStateHeader {
89+
let result = match self {
90+
AccumulateState::Owned(p) => p.as_ptr(),
91+
AccumulateState::Borrowed(ref p) => {
92+
*p as *const AccumulateStateHeader as *mut AccumulateStateHeader
93+
}
94+
};
95+
std::mem::forget(self);
96+
result
97+
}
98+
}
99+
100+
impl Deref for AccumulateState<'_> {
101+
type Target = AccumulateStateHeader;
102+
103+
fn deref(&self) -> &Self::Target {
104+
match self {
105+
AccumulateState::Owned(p) => unsafe { p.as_ref() },
106+
AccumulateState::Borrowed(p) => p,
107+
}
108+
}
109+
}
110+
111+
impl DerefMut for AccumulateState<'_> {
112+
fn deref_mut(&mut self) -> &mut Self::Target {
113+
match self {
114+
AccumulateState::Owned(p) => unsafe { p.as_mut() },
115+
AccumulateState::Borrowed(p) => p,
116+
}
117+
}
118+
}
119+
120+
impl Drop for AccumulateState<'_> {
121+
fn drop(&mut self) {
122+
match self {
123+
AccumulateState::Owned(p) => unsafe {
124+
pgrx::pg_sys::pfree(p.as_ptr().cast());
125+
},
126+
AccumulateState::Borrowed(_) => {}
127+
}
128+
}
129+
}
130+
131+
impl FromDatum for AccumulateState<'_> {
132+
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typmod: Oid) -> Option<Self> {
133+
if is_null {
134+
None
135+
} else {
136+
let ptr = NonNull::new(datum.cast_mut_ptr::<AccumulateStateHeader>()).unwrap();
137+
unsafe { Some(AccumulateState::new(ptr)) }
138+
}
139+
}
140+
}
141+
142+
impl IntoDatum for AccumulateState<'_> {
143+
fn into_datum(self) -> Option<Datum> {
144+
Some(Datum::from(self.into_raw() as *mut ()))
145+
}
146+
147+
fn type_oid() -> Oid {
148+
let namespace = pgrx::pg_catalog::PgNamespace::search_namespacename(c"vectors").unwrap();
149+
let namespace = namespace.get().expect("pgvecto.rs is not installed.");
150+
let t = pgrx::pg_catalog::PgType::search_typenamensp(c"accumulate_state", namespace.oid())
151+
.unwrap();
152+
let t = t.get().expect("pg_catalog is broken.");
153+
t.oid()
154+
}
155+
}
156+
157+
unsafe impl SqlTranslatable for AccumulateState<'_> {
158+
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
159+
Ok(SqlMapping::As(String::from("accumulate_state")))
160+
}
161+
fn return_sql() -> Result<Returns, ReturnsError> {
162+
Ok(Returns::One(SqlMapping::As(String::from(
163+
"accumulate_state",
164+
))))
165+
}
166+
}
167+
168+
fn parse_accumulate_state(input: &[u8]) -> Result<(u64, Vec<F32>), String> {
169+
use crate::utils::parse::parse_vector;
170+
let hint = "Invalid input format for accumulatestate, using \'bigint, array \' like \'1, [1]\'";
171+
let (count, slice) = input.split_once(|&c| c == b',').ok_or(hint)?;
172+
let count = std::str::from_utf8(count)
173+
.map_err(|e| e.to_string() + "\n" + hint)?
174+
.parse::<u64>()
175+
.map_err(|e| e.to_string() + "\n" + hint)?;
176+
let v = parse_vector(slice, 0, |s| s.parse().ok());
177+
match v {
178+
Err(e) => Err(e.to_string() + "\n" + hint),
179+
Ok(vector) => Ok((count, vector)),
180+
}
181+
}
182+
183+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
184+
fn _vectors_accumulate_state_in(input: &CStr, _oid: Oid, _typmod: i32) -> AccumulateState<'_> {
185+
// parse one bigint and a vector of f32, split with a comma
186+
let res = parse_accumulate_state(input.to_bytes());
187+
match res {
188+
Err(e) => {
189+
bad_literal(&e.to_string());
190+
}
191+
Ok((count, vector)) => AccumulateState::new_with_slice(count, &vector),
192+
}
193+
}
194+
195+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
196+
fn _vectors_accumulate_state_out(state: AccumulateState<'_>) -> CString {
197+
let mut buffer = String::new();
198+
buffer.push_str(format!("{}, ", state.count()).as_str());
199+
buffer.push('[');
200+
if let Some(&x) = state.slice().first() {
201+
buffer.push_str(format!("{}", x).as_str());
202+
}
203+
for &x in state.slice().iter().skip(1) {
204+
buffer.push_str(format!(", {}", x).as_str());
205+
}
206+
buffer.push(']');
207+
CString::new(buffer).unwrap()
208+
}
5209

6210
/// accumulate intermediate state for vector average
7211
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
8-
fn _vectors_vector_accum(
9-
mut state: pgrx::composite_type!('static, "vectors.vector_accum_state"),
212+
fn _vectors_vector_accum<'a>(
213+
mut state: AccumulateState<'a>,
10214
value: Vecf32Input<'_>,
11-
) -> pgrx::composite_type!('static, "vectors.vector_accum_state") {
12-
let count = state
13-
.get_by_name::<i64>("count")
14-
.unwrap()
15-
.unwrap_or_default();
16-
if count == 0 {
17-
// state is empty
18-
let mut result =
19-
pgrx::heap_tuple::PgHeapTuple::new_composite_type("vectors.vector_accum_state")
20-
.unwrap();
21-
let sum = value.iter().map(|x| x.0 as f64).collect::<Vec<_>>();
22-
result.set_by_name("count", count + 1).unwrap();
23-
result.set_by_name("sum", sum).unwrap();
24-
result
25-
} else {
26-
let sum = state
27-
.get_by_name::<pgrx::Array<f64>>("sum")
28-
.unwrap()
29-
.unwrap();
30-
check_matched_dims(sum.len(), value.dims());
31-
// TODO: pgrx::Array<T> don't support mutable operations currently, we can reuse the state once it's supported.
32-
let sum = sum
33-
.iter_deny_null()
34-
.zip(value.iter())
35-
.map(|(x, y)| x + y.0 as f64)
36-
.collect::<Vec<_>>();
37-
state.set_by_name("count", count + 1).unwrap();
38-
state.set_by_name("sum", sum).unwrap();
39-
state
215+
) -> AccumulateState<'a> {
216+
let count = state.count();
217+
match count {
218+
// if the state is empty, copy the input vector
219+
0 => AccumulateState::new_with_slice(1, value.iter().as_slice()),
220+
_ => {
221+
let dims = state.dims();
222+
let value_dims = value.dims();
223+
check_matched_dims(dims, value_dims);
224+
let sum = state.slice_mut();
225+
// accumulate the input vector
226+
for (x, y) in sum.iter_mut().zip(value.iter()) {
227+
*x += *y;
228+
}
229+
// increase the count
230+
state.count += 1;
231+
state
232+
}
40233
}
41234
}
42235

236+
/// combine two intermediate states for vector average
43237
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
44-
fn _vectors_vector_combine(
45-
state1: pgrx::composite_type!('static, "vectors.vector_accum_state"),
46-
state2: pgrx::composite_type!('static, "vectors.vector_accum_state"),
47-
) -> pgrx::composite_type!('static, "vectors.vector_accum_state") {
48-
let count1 = state1
49-
.get_by_name::<i64>("count")
50-
.unwrap()
51-
.unwrap_or_default();
52-
let count2 = state2
53-
.get_by_name::<i64>("count")
54-
.unwrap()
55-
.unwrap_or_default();
238+
fn _vectors_vector_combine<'a>(
239+
mut state1: AccumulateState<'a>,
240+
state2: AccumulateState<'a>,
241+
) -> AccumulateState<'a> {
242+
let count1 = state1.count();
243+
let count2 = state2.count();
56244
if count1 == 0 {
57245
state2
58246
} else if count2 == 0 {
59247
state1
60248
} else {
61-
let sum1 = state1
62-
.get_by_name::<pgrx::Array<f64>>("sum")
63-
.unwrap()
64-
.unwrap();
65-
let sum2 = state2
66-
.get_by_name::<pgrx::Array<f64>>("sum")
67-
.unwrap()
68-
.unwrap();
69-
check_matched_dims(sum1.len(), sum2.len());
70-
let sum = sum1
71-
.iter_deny_null()
72-
.zip(sum2.iter_deny_null())
73-
.map(|(x, y)| x + y)
74-
.collect::<Vec<_>>();
75-
let mut result =
76-
pgrx::heap_tuple::PgHeapTuple::new_composite_type("vectors.vector_accum_state")
77-
.unwrap();
78-
// merge two accumulate states
79-
result.set_by_name("count", count1 + count2).unwrap();
80-
result.set_by_name("sum", sum).unwrap();
81-
result
249+
let dims1 = state1.dims();
250+
let dims2 = state2.dims();
251+
check_matched_dims(dims1, dims2);
252+
state1.count += count2;
253+
let sum1 = state1.slice_mut();
254+
let sum2 = state2.slice();
255+
for (x, y) in sum1.iter_mut().zip(sum2.iter()) {
256+
*x += *y;
257+
}
258+
state1
82259
}
83260
}
84261

262+
/// finalize the intermediate state for vector average
85263
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
86-
fn _vectors_vector_final(
87-
state: pgrx::composite_type!('static, "vectors.vector_accum_state"),
88-
) -> Option<Vecf32Output> {
89-
let count = state
90-
.get_by_name::<i64>("count")
91-
.unwrap()
92-
.unwrap_or_default();
93-
// return null datum if all inputs vector are null
264+
fn _vectors_vector_final(state: AccumulateState<'_>) -> Option<Vecf32Output> {
265+
let count = state.count();
94266
if count == 0 {
267+
// return NULL if all inputs are NULL
95268
return None;
96269
}
97270
let sum = state
98-
.get_by_name::<pgrx::Array<f64>>("sum")
99-
.unwrap()
100-
.unwrap();
101-
// compute the average of vectors by dividing the sum by the count
102-
let sum = sum
103-
.iter_deny_null()
104-
.map(|x| F32((x / count as f64) as f32))
271+
.slice()
272+
.iter()
273+
.map(|x| *x / F32(count as f32))
105274
.collect::<Vec<_>>();
106275
Some(Vecf32Output::new(
107276
Vecf32Borrowed::new_checked(&sum).unwrap(),
@@ -117,9 +286,7 @@ fn _vectors_vector_dims(vector: Vecf32Input<'_>) -> i32 {
117286
/// Calculate the l2 norm of a vector.
118287
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
119288
fn _vectors_vector_norm(vector: Vecf32Input<'_>) -> f32 {
120-
vector
121-
.iter()
122-
.map(|x: &F32| x.0 as f64 * x.0 as f64)
123-
.sum::<f64>()
124-
.sqrt() as f32
289+
Vecf32Dot::distance(vector.for_borrow(), vector.for_borrow())
290+
.to_f32()
291+
.sqrt()
125292
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//!
33
//! Provides an easy-to-use extension for vector similarity search.
44
#![feature(alloc_error_hook)]
5+
#![feature(slice_split_once)]
56
#![allow(clippy::needless_range_loop)]
67
#![allow(clippy::single_match)]
78
#![allow(clippy::too_many_arguments)]

src/sql/bootstrap.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ CREATE TYPE svector;
88
CREATE TYPE bvector;
99
CREATE TYPE veci8;
1010
CREATE TYPE vector_index_stat;
11-
CREATE TYPE vector_accum_state;
11+
CREATE TYPE accumulate_state;
1212

1313
-- bootstrap end

0 commit comments

Comments
 (0)