1+ #![ allow( unused_lifetimes) ]
2+ #![ allow( clippy:: extra_unused_lifetimes) ]
13use crate :: datatype:: memory_vecf32:: { Vecf32Input , Vecf32Output } ;
24use crate :: error:: * ;
5+ use base:: operator:: { Operator , Vecf32Dot } ;
36use base:: scalar:: * ;
47use base:: vector:: * ;
8+ use pgrx:: pg_sys:: Datum ;
9+ use pgrx:: pg_sys:: Oid ;
10+ use pgrx:: pgrx_sql_entity_graph:: metadata:: ArgumentError ;
11+ use pgrx:: pgrx_sql_entity_graph:: metadata:: Returns ;
12+ use pgrx:: pgrx_sql_entity_graph:: metadata:: ReturnsError ;
13+ use pgrx:: pgrx_sql_entity_graph:: metadata:: SqlMapping ;
14+ use pgrx:: pgrx_sql_entity_graph:: metadata:: SqlTranslatable ;
15+ use pgrx:: { FromDatum , IntoDatum } ;
16+ use std:: alloc:: Layout ;
17+ use std:: ffi:: { CStr , CString } ;
18+ use std:: ops:: { Deref , DerefMut } ;
19+ use std:: ptr:: NonNull ;
20+
21+ #[ repr( C , align( 8 ) ) ]
22+ pub struct AccumulateStateHeader {
23+ varlena : u32 ,
24+ dims : u16 ,
25+ count : u64 ,
26+ phantom : [ F32 ; 0 ] ,
27+ }
28+
29+ impl AccumulateStateHeader {
30+ fn varlena ( size : usize ) -> u32 {
31+ ( size << 2 ) as u32
32+ }
33+ fn layout ( len : usize ) -> Layout {
34+ u16:: try_from ( len) . expect ( "Vector is too large." ) ;
35+ let layout_alpha = Layout :: new :: < AccumulateStateHeader > ( ) ;
36+ let layout_beta = Layout :: array :: < F32 > ( len) . unwrap ( ) ;
37+ let layout = layout_alpha. extend ( layout_beta) . unwrap ( ) . 0 ;
38+ layout. pad_to_align ( )
39+ }
40+ pub fn dims ( & self ) -> usize {
41+ self . dims as usize
42+ }
43+ pub fn count ( & self ) -> u64 {
44+ self . count
45+ }
46+ pub fn slice ( & self ) -> & [ F32 ] {
47+ unsafe { std:: slice:: from_raw_parts ( self . phantom . as_ptr ( ) , self . dims as usize ) }
48+ }
49+ pub fn slice_mut ( & mut self ) -> & mut [ F32 ] {
50+ unsafe { std:: slice:: from_raw_parts_mut ( self . phantom . as_mut_ptr ( ) , self . dims as usize ) }
51+ }
52+ }
53+
54+ pub enum AccumulateState < ' a > {
55+ Owned ( NonNull < AccumulateStateHeader > ) ,
56+ Borrowed ( & ' a mut AccumulateStateHeader ) ,
57+ }
58+
59+ impl < ' a > AccumulateState < ' a > {
60+ unsafe fn new ( p : NonNull < AccumulateStateHeader > ) -> Self {
61+ // datum maybe toasted, try to detoast it
62+ let q = unsafe {
63+ NonNull :: new ( pgrx:: pg_sys:: pg_detoast_datum ( p. as_ptr ( ) . cast ( ) ) . cast ( ) ) . unwrap ( )
64+ } ;
65+ if p != q {
66+ AccumulateState :: Owned ( q)
67+ } else {
68+ unsafe { AccumulateState :: Borrowed ( & mut * p. as_ptr ( ) ) }
69+ }
70+ }
71+
72+ pub fn new_with_slice ( count : u64 , slice : & [ F32 ] ) -> Self {
73+ let dims = slice. len ( ) ;
74+ let layout = AccumulateStateHeader :: layout ( dims) ;
75+ unsafe {
76+ let ptr = pgrx:: pg_sys:: palloc ( layout. size ( ) ) as * mut AccumulateStateHeader ;
77+ std:: ptr:: addr_of_mut!( ( * ptr) . varlena)
78+ . write ( AccumulateStateHeader :: varlena ( layout. size ( ) ) ) ;
79+ std:: ptr:: addr_of_mut!( ( * ptr) . dims) . write ( dims as u16 ) ;
80+ std:: ptr:: addr_of_mut!( ( * ptr) . count) . write ( count) ;
81+ if dims > 0 {
82+ std:: ptr:: copy_nonoverlapping ( slice. as_ptr ( ) , ( * ptr) . phantom . as_mut_ptr ( ) , dims) ;
83+ }
84+ AccumulateState :: Owned ( NonNull :: new ( ptr) . unwrap ( ) )
85+ }
86+ }
87+
88+ pub fn into_raw ( self ) -> * mut AccumulateStateHeader {
89+ let result = match self {
90+ AccumulateState :: Owned ( p) => p. as_ptr ( ) ,
91+ AccumulateState :: Borrowed ( ref p) => {
92+ * p as * const AccumulateStateHeader as * mut AccumulateStateHeader
93+ }
94+ } ;
95+ std:: mem:: forget ( self ) ;
96+ result
97+ }
98+ }
99+
100+ impl Deref for AccumulateState < ' _ > {
101+ type Target = AccumulateStateHeader ;
102+
103+ fn deref ( & self ) -> & Self :: Target {
104+ match self {
105+ AccumulateState :: Owned ( p) => unsafe { p. as_ref ( ) } ,
106+ AccumulateState :: Borrowed ( p) => p,
107+ }
108+ }
109+ }
110+
111+ impl DerefMut for AccumulateState < ' _ > {
112+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
113+ match self {
114+ AccumulateState :: Owned ( p) => unsafe { p. as_mut ( ) } ,
115+ AccumulateState :: Borrowed ( p) => p,
116+ }
117+ }
118+ }
119+
120+ impl Drop for AccumulateState < ' _ > {
121+ fn drop ( & mut self ) {
122+ match self {
123+ AccumulateState :: Owned ( p) => unsafe {
124+ pgrx:: pg_sys:: pfree ( p. as_ptr ( ) . cast ( ) ) ;
125+ } ,
126+ AccumulateState :: Borrowed ( _) => { }
127+ }
128+ }
129+ }
130+
131+ impl FromDatum for AccumulateState < ' _ > {
132+ unsafe fn from_polymorphic_datum ( datum : Datum , is_null : bool , _typmod : Oid ) -> Option < Self > {
133+ if is_null {
134+ None
135+ } else {
136+ let ptr = NonNull :: new ( datum. cast_mut_ptr :: < AccumulateStateHeader > ( ) ) . unwrap ( ) ;
137+ unsafe { Some ( AccumulateState :: new ( ptr) ) }
138+ }
139+ }
140+ }
141+
142+ impl IntoDatum for AccumulateState < ' _ > {
143+ fn into_datum ( self ) -> Option < Datum > {
144+ Some ( Datum :: from ( self . into_raw ( ) as * mut ( ) ) )
145+ }
146+
147+ fn type_oid ( ) -> Oid {
148+ let namespace = pgrx:: pg_catalog:: PgNamespace :: search_namespacename ( c"vectors" ) . unwrap ( ) ;
149+ let namespace = namespace. get ( ) . expect ( "pgvecto.rs is not installed." ) ;
150+ let t = pgrx:: pg_catalog:: PgType :: search_typenamensp ( c"accumulate_state" , namespace. oid ( ) )
151+ . unwrap ( ) ;
152+ let t = t. get ( ) . expect ( "pg_catalog is broken." ) ;
153+ t. oid ( )
154+ }
155+ }
156+
157+ unsafe impl SqlTranslatable for AccumulateState < ' _ > {
158+ fn argument_sql ( ) -> Result < SqlMapping , ArgumentError > {
159+ Ok ( SqlMapping :: As ( String :: from ( "accumulate_state" ) ) )
160+ }
161+ fn return_sql ( ) -> Result < Returns , ReturnsError > {
162+ Ok ( Returns :: One ( SqlMapping :: As ( String :: from (
163+ "accumulate_state" ,
164+ ) ) ) )
165+ }
166+ }
167+
168+ fn parse_accumulate_state ( input : & [ u8 ] ) -> Result < ( u64 , Vec < F32 > ) , String > {
169+ use crate :: utils:: parse:: parse_vector;
170+ let hint = "Invalid input format for accumulatestate, using \' bigint, array \' like \' 1, [1]\' " ;
171+ let ( count, slice) = input. split_once ( |& c| c == b',' ) . ok_or ( hint) ?;
172+ let count = std:: str:: from_utf8 ( count)
173+ . map_err ( |e| e. to_string ( ) + "\n " + hint) ?
174+ . parse :: < u64 > ( )
175+ . map_err ( |e| e. to_string ( ) + "\n " + hint) ?;
176+ let v = parse_vector ( slice, 0 , |s| s. parse ( ) . ok ( ) ) ;
177+ match v {
178+ Err ( e) => Err ( e. to_string ( ) + "\n " + hint) ,
179+ Ok ( vector) => Ok ( ( count, vector) ) ,
180+ }
181+ }
182+
183+ #[ pgrx:: pg_extern( immutable, strict, parallel_safe) ]
184+ fn _vectors_accumulate_state_in ( input : & CStr , _oid : Oid , _typmod : i32 ) -> AccumulateState < ' _ > {
185+ // parse one bigint and a vector of f32, split with a comma
186+ let res = parse_accumulate_state ( input. to_bytes ( ) ) ;
187+ match res {
188+ Err ( e) => {
189+ bad_literal ( & e. to_string ( ) ) ;
190+ }
191+ Ok ( ( count, vector) ) => AccumulateState :: new_with_slice ( count, & vector) ,
192+ }
193+ }
194+
195+ #[ pgrx:: pg_extern( immutable, strict, parallel_safe) ]
196+ fn _vectors_accumulate_state_out ( state : AccumulateState < ' _ > ) -> CString {
197+ let mut buffer = String :: new ( ) ;
198+ buffer. push_str ( format ! ( "{}, " , state. count( ) ) . as_str ( ) ) ;
199+ buffer. push ( '[' ) ;
200+ if let Some ( & x) = state. slice ( ) . first ( ) {
201+ buffer. push_str ( format ! ( "{}" , x) . as_str ( ) ) ;
202+ }
203+ for & x in state. slice ( ) . iter ( ) . skip ( 1 ) {
204+ buffer. push_str ( format ! ( ", {}" , x) . as_str ( ) ) ;
205+ }
206+ buffer. push ( ']' ) ;
207+ CString :: new ( buffer) . unwrap ( )
208+ }
5209
6210/// accumulate intermediate state for vector average
7211#[ pgrx:: pg_extern( immutable, strict, parallel_safe) ]
8- fn _vectors_vector_accum (
9- mut state : pgrx :: composite_type! ( ' static , "vectors.vector_accum_state" ) ,
212+ fn _vectors_vector_accum < ' a > (
213+ mut state : AccumulateState < ' a > ,
10214 value : Vecf32Input < ' _ > ,
11- ) -> pgrx:: composite_type!( ' static , "vectors.vector_accum_state" ) {
12- let count = state
13- . get_by_name :: < i64 > ( "count" )
14- . unwrap ( )
15- . unwrap_or_default ( ) ;
16- if count == 0 {
17- // state is empty
18- let mut result =
19- pgrx:: heap_tuple:: PgHeapTuple :: new_composite_type ( "vectors.vector_accum_state" )
20- . unwrap ( ) ;
21- let sum = value. iter ( ) . map ( |x| x. 0 as f64 ) . collect :: < Vec < _ > > ( ) ;
22- result. set_by_name ( "count" , count + 1 ) . unwrap ( ) ;
23- result. set_by_name ( "sum" , sum) . unwrap ( ) ;
24- result
25- } else {
26- let sum = state
27- . get_by_name :: < pgrx:: Array < f64 > > ( "sum" )
28- . unwrap ( )
29- . unwrap ( ) ;
30- check_matched_dims ( sum. len ( ) , value. dims ( ) ) ;
31- // TODO: pgrx::Array<T> don't support mutable operations currently, we can reuse the state once it's supported.
32- let sum = sum
33- . iter_deny_null ( )
34- . zip ( value. iter ( ) )
35- . map ( |( x, y) | x + y. 0 as f64 )
36- . collect :: < Vec < _ > > ( ) ;
37- state. set_by_name ( "count" , count + 1 ) . unwrap ( ) ;
38- state. set_by_name ( "sum" , sum) . unwrap ( ) ;
39- state
215+ ) -> AccumulateState < ' a > {
216+ let count = state. count ( ) ;
217+ match count {
218+ // if the state is empty, copy the input vector
219+ 0 => AccumulateState :: new_with_slice ( 1 , value. iter ( ) . as_slice ( ) ) ,
220+ _ => {
221+ let dims = state. dims ( ) ;
222+ let value_dims = value. dims ( ) ;
223+ check_matched_dims ( dims, value_dims) ;
224+ let sum = state. slice_mut ( ) ;
225+ // accumulate the input vector
226+ for ( x, y) in sum. iter_mut ( ) . zip ( value. iter ( ) ) {
227+ * x += * y;
228+ }
229+ // increase the count
230+ state. count += 1 ;
231+ state
232+ }
40233 }
41234}
42235
236+ /// combine two intermediate states for vector average
43237#[ pgrx:: pg_extern( immutable, strict, parallel_safe) ]
44- fn _vectors_vector_combine (
45- state1 : pgrx:: composite_type!( ' static , "vectors.vector_accum_state" ) ,
46- state2 : pgrx:: composite_type!( ' static , "vectors.vector_accum_state" ) ,
47- ) -> pgrx:: composite_type!( ' static , "vectors.vector_accum_state" ) {
48- let count1 = state1
49- . get_by_name :: < i64 > ( "count" )
50- . unwrap ( )
51- . unwrap_or_default ( ) ;
52- let count2 = state2
53- . get_by_name :: < i64 > ( "count" )
54- . unwrap ( )
55- . unwrap_or_default ( ) ;
238+ fn _vectors_vector_combine < ' a > (
239+ mut state1 : AccumulateState < ' a > ,
240+ state2 : AccumulateState < ' a > ,
241+ ) -> AccumulateState < ' a > {
242+ let count1 = state1. count ( ) ;
243+ let count2 = state2. count ( ) ;
56244 if count1 == 0 {
57245 state2
58246 } else if count2 == 0 {
59247 state1
60248 } else {
61- let sum1 = state1
62- . get_by_name :: < pgrx:: Array < f64 > > ( "sum" )
63- . unwrap ( )
64- . unwrap ( ) ;
65- let sum2 = state2
66- . get_by_name :: < pgrx:: Array < f64 > > ( "sum" )
67- . unwrap ( )
68- . unwrap ( ) ;
69- check_matched_dims ( sum1. len ( ) , sum2. len ( ) ) ;
70- let sum = sum1
71- . iter_deny_null ( )
72- . zip ( sum2. iter_deny_null ( ) )
73- . map ( |( x, y) | x + y)
74- . collect :: < Vec < _ > > ( ) ;
75- let mut result =
76- pgrx:: heap_tuple:: PgHeapTuple :: new_composite_type ( "vectors.vector_accum_state" )
77- . unwrap ( ) ;
78- // merge two accumulate states
79- result. set_by_name ( "count" , count1 + count2) . unwrap ( ) ;
80- result. set_by_name ( "sum" , sum) . unwrap ( ) ;
81- result
249+ let dims1 = state1. dims ( ) ;
250+ let dims2 = state2. dims ( ) ;
251+ check_matched_dims ( dims1, dims2) ;
252+ state1. count += count2;
253+ let sum1 = state1. slice_mut ( ) ;
254+ let sum2 = state2. slice ( ) ;
255+ for ( x, y) in sum1. iter_mut ( ) . zip ( sum2. iter ( ) ) {
256+ * x += * y;
257+ }
258+ state1
82259 }
83260}
84261
262+ /// finalize the intermediate state for vector average
85263#[ pgrx:: pg_extern( immutable, strict, parallel_safe) ]
86- fn _vectors_vector_final (
87- state : pgrx:: composite_type!( ' static , "vectors.vector_accum_state" ) ,
88- ) -> Option < Vecf32Output > {
89- let count = state
90- . get_by_name :: < i64 > ( "count" )
91- . unwrap ( )
92- . unwrap_or_default ( ) ;
93- // return null datum if all inputs vector are null
264+ fn _vectors_vector_final ( state : AccumulateState < ' _ > ) -> Option < Vecf32Output > {
265+ let count = state. count ( ) ;
94266 if count == 0 {
267+ // return NULL if all inputs are NULL
95268 return None ;
96269 }
97270 let sum = state
98- . get_by_name :: < pgrx:: Array < f64 > > ( "sum" )
99- . unwrap ( )
100- . unwrap ( ) ;
101- // compute the average of vectors by dividing the sum by the count
102- let sum = sum
103- . iter_deny_null ( )
104- . map ( |x| F32 ( ( x / count as f64 ) as f32 ) )
271+ . slice ( )
272+ . iter ( )
273+ . map ( |x| * x / F32 ( count as f32 ) )
105274 . collect :: < Vec < _ > > ( ) ;
106275 Some ( Vecf32Output :: new (
107276 Vecf32Borrowed :: new_checked ( & sum) . unwrap ( ) ,
@@ -117,9 +286,7 @@ fn _vectors_vector_dims(vector: Vecf32Input<'_>) -> i32 {
117286/// Calculate the l2 norm of a vector.
118287#[ pgrx:: pg_extern( immutable, strict, parallel_safe) ]
119288fn _vectors_vector_norm ( vector : Vecf32Input < ' _ > ) -> f32 {
120- vector
121- . iter ( )
122- . map ( |x : & F32 | x. 0 as f64 * x. 0 as f64 )
123- . sum :: < f64 > ( )
124- . sqrt ( ) as f32
289+ Vecf32Dot :: distance ( vector. for_borrow ( ) , vector. for_borrow ( ) )
290+ . to_f32 ( )
291+ . sqrt ( )
125292}
0 commit comments