Skip to content

Commit fdbfebf

Browse files
authored
feat: Add VEC_PRODUCT, VEC_ELEM_PRODUCT, VEC_NORM. (#5303)
* feat: Add `vec_product(col)` function. * feat: Add `vec_elem_product` function * feat: Add `vec_norm` function.
1 parent 812a775 commit fdbfebf

File tree

9 files changed

+686
-0
lines changed

9 files changed

+686
-0
lines changed

src/common/function/src/scalars/aggregate.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator;
3232
pub use scipy_stats_norm_pdf::ScipyStatsNormPdfAccumulatorCreator;
3333

3434
use crate::function_registry::FunctionRegistry;
35+
use crate::scalars::vector::product::VectorProductCreator;
3536
use crate::scalars::vector::sum::VectorSumCreator;
3637

3738
/// A function creates `AggregateFunctionCreator`.
@@ -93,6 +94,7 @@ impl AggregateFunctions {
9394
register_aggr_func!("scipystatsnormcdf", 2, ScipyStatsNormCdfAccumulatorCreator);
9495
register_aggr_func!("scipystatsnormpdf", 2, ScipyStatsNormPdfAccumulatorCreator);
9596
register_aggr_func!("vec_sum", 1, VectorSumCreator);
97+
register_aggr_func!("vec_product", 1, VectorProductCreator);
9698

9799
#[cfg(feature = "geo")]
98100
register_aggr_func!(

src/common/function/src/scalars/vector.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414

1515
mod convert;
1616
mod distance;
17+
mod elem_product;
1718
mod elem_sum;
1819
pub mod impl_conv;
20+
pub(crate) mod product;
1921
mod scalar_add;
2022
mod scalar_mul;
2123
mod sub;
2224
pub(crate) mod sum;
2325
mod vector_div;
2426
mod vector_mul;
27+
mod vector_norm;
2528

2629
use std::sync::Arc;
2730

@@ -46,8 +49,10 @@ impl VectorFunction {
4649

4750
// vector calculation
4851
registry.register(Arc::new(vector_mul::VectorMulFunction));
52+
registry.register(Arc::new(vector_norm::VectorNormFunction));
4953
registry.register(Arc::new(vector_div::VectorDivFunction));
5054
registry.register(Arc::new(sub::SubFunction));
5155
registry.register(Arc::new(elem_sum::ElemSumFunction));
56+
registry.register(Arc::new(elem_product::ElemProductFunction));
5257
}
5358
}
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright 2023 Greptime Team
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::borrow::Cow;
16+
use std::fmt::Display;
17+
18+
use common_query::error::InvalidFuncArgsSnafu;
19+
use common_query::prelude::{Signature, TypeSignature, Volatility};
20+
use datatypes::prelude::ConcreteDataType;
21+
use datatypes::scalars::ScalarVectorBuilder;
22+
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
23+
use nalgebra::DVectorView;
24+
use snafu::ensure;
25+
26+
use crate::function::{Function, FunctionContext};
27+
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28+
29+
const NAME: &str = "vec_elem_product";
30+
31+
/// Multiplies all elements of the vector, returns a scalar.
32+
///
33+
/// # Example
34+
///
35+
/// ```sql
36+
/// SELECT vec_elem_product(parse_vec('[1.0, 2.0, 3.0, 4.0]'));
37+
///
38+
// +-----------------------------------------------------------+
39+
// | vec_elem_product(parse_vec(Utf8("[1.0, 2.0, 3.0, 4.0]"))) |
40+
// +-----------------------------------------------------------+
41+
// | 24.0 |
42+
// +-----------------------------------------------------------+
43+
/// ``````
44+
#[derive(Debug, Clone, Default)]
45+
pub struct ElemProductFunction;
46+
47+
impl Function for ElemProductFunction {
48+
fn name(&self) -> &str {
49+
NAME
50+
}
51+
52+
fn return_type(
53+
&self,
54+
_input_types: &[ConcreteDataType],
55+
) -> common_query::error::Result<ConcreteDataType> {
56+
Ok(ConcreteDataType::float32_datatype())
57+
}
58+
59+
fn signature(&self) -> Signature {
60+
Signature::one_of(
61+
vec![
62+
TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
63+
TypeSignature::Exact(vec![ConcreteDataType::binary_datatype()]),
64+
],
65+
Volatility::Immutable,
66+
)
67+
}
68+
69+
fn eval(
70+
&self,
71+
_func_ctx: FunctionContext,
72+
columns: &[VectorRef],
73+
) -> common_query::error::Result<VectorRef> {
74+
ensure!(
75+
columns.len() == 1,
76+
InvalidFuncArgsSnafu {
77+
err_msg: format!(
78+
"The length of the args is not correct, expect exactly one, have: {}",
79+
columns.len()
80+
)
81+
}
82+
);
83+
let arg0 = &columns[0];
84+
85+
let len = arg0.len();
86+
let mut result = Float32VectorBuilder::with_capacity(len);
87+
if len == 0 {
88+
return Ok(result.to_vector());
89+
}
90+
91+
let arg0_const = as_veclit_if_const(arg0)?;
92+
93+
for i in 0..len {
94+
let arg0 = match arg0_const.as_ref() {
95+
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
96+
None => as_veclit(arg0.get_ref(i))?,
97+
};
98+
let Some(arg0) = arg0 else {
99+
result.push_null();
100+
continue;
101+
};
102+
result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product()));
103+
}
104+
105+
Ok(result.to_vector())
106+
}
107+
}
108+
109+
impl Display for ElemProductFunction {
110+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111+
write!(f, "{}", NAME.to_ascii_uppercase())
112+
}
113+
}
114+
115+
#[cfg(test)]
116+
mod tests {
117+
use std::sync::Arc;
118+
119+
use datatypes::vectors::StringVector;
120+
121+
use super::*;
122+
use crate::function::FunctionContext;
123+
124+
#[test]
125+
fn test_elem_product() {
126+
let func = ElemProductFunction;
127+
128+
let input0 = Arc::new(StringVector::from(vec![
129+
Some("[1.0,2.0,3.0]".to_string()),
130+
Some("[4.0,5.0,6.0]".to_string()),
131+
None,
132+
]));
133+
134+
let result = func.eval(FunctionContext::default(), &[input0]).unwrap();
135+
136+
let result = result.as_ref();
137+
assert_eq!(result.len(), 3);
138+
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
139+
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(120.0));
140+
assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
141+
}
142+
}
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
// Copyright 2023 Greptime Team
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::sync::Arc;
16+
17+
use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
18+
use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu};
19+
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
20+
use common_query::prelude::AccumulatorCreatorFunction;
21+
use datatypes::prelude::{ConcreteDataType, Value, *};
22+
use datatypes::vectors::VectorRef;
23+
use nalgebra::{Const, DVectorView, Dyn, OVector};
24+
use snafu::ensure;
25+
26+
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
27+
28+
/// Aggregates by multiplying elements across the same dimension, returns a vector.
29+
#[derive(Debug, Default)]
30+
pub struct VectorProduct {
31+
product: Option<OVector<f32, Dyn>>,
32+
has_null: bool,
33+
}
34+
35+
#[as_aggr_func_creator]
36+
#[derive(Debug, Default, AggrFuncTypeStore)]
37+
pub struct VectorProductCreator {}
38+
39+
impl AggregateFunctionCreator for VectorProductCreator {
40+
fn creator(&self) -> AccumulatorCreatorFunction {
41+
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
42+
ensure!(
43+
types.len() == 1,
44+
InvalidFuncArgsSnafu {
45+
err_msg: format!(
46+
"The length of the args is not correct, expect exactly one, have: {}",
47+
types.len()
48+
)
49+
}
50+
);
51+
let input_type = &types[0];
52+
match input_type {
53+
ConcreteDataType::String(_) | ConcreteDataType::Binary(_) => {
54+
Ok(Box::new(VectorProduct::default()))
55+
}
56+
_ => {
57+
let err_msg = format!(
58+
"\"VEC_PRODUCT\" aggregate function not support data type {:?}",
59+
input_type.logical_type_id(),
60+
);
61+
CreateAccumulatorSnafu { err_msg }.fail()?
62+
}
63+
}
64+
});
65+
creator
66+
}
67+
68+
fn output_type(&self) -> common_query::error::Result<ConcreteDataType> {
69+
Ok(ConcreteDataType::binary_datatype())
70+
}
71+
72+
fn state_types(&self) -> common_query::error::Result<Vec<ConcreteDataType>> {
73+
Ok(vec![self.output_type()?])
74+
}
75+
}
76+
77+
impl VectorProduct {
78+
fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
79+
self.product.get_or_insert_with(|| {
80+
OVector::from_iterator_generic(Dyn(len), Const::<1>, (0..len).map(|_| 1.0))
81+
})
82+
}
83+
84+
fn update(&mut self, values: &[VectorRef], is_update: bool) -> Result<(), Error> {
85+
if values.is_empty() || self.has_null {
86+
return Ok(());
87+
};
88+
let column = &values[0];
89+
let len = column.len();
90+
91+
match as_veclit_if_const(column)? {
92+
Some(column) => {
93+
let vec_column = DVectorView::from_slice(&column, column.len()).scale(len as f32);
94+
*self.inner(vec_column.len()) =
95+
(*self.inner(vec_column.len())).component_mul(&vec_column);
96+
}
97+
None => {
98+
for i in 0..len {
99+
let Some(arg0) = as_veclit(column.get_ref(i))? else {
100+
if is_update {
101+
self.has_null = true;
102+
self.product = None;
103+
}
104+
return Ok(());
105+
};
106+
let vec_column = DVectorView::from_slice(&arg0, arg0.len());
107+
*self.inner(vec_column.len()) =
108+
(*self.inner(vec_column.len())).component_mul(&vec_column);
109+
}
110+
}
111+
}
112+
Ok(())
113+
}
114+
}
115+
116+
impl Accumulator for VectorProduct {
117+
fn state(&self) -> common_query::error::Result<Vec<Value>> {
118+
self.evaluate().map(|v| vec![v])
119+
}
120+
121+
fn update_batch(&mut self, values: &[VectorRef]) -> common_query::error::Result<()> {
122+
self.update(values, true)
123+
}
124+
125+
fn merge_batch(&mut self, states: &[VectorRef]) -> common_query::error::Result<()> {
126+
self.update(states, false)
127+
}
128+
129+
fn evaluate(&self) -> common_query::error::Result<Value> {
130+
match &self.product {
131+
None => Ok(Value::Null),
132+
Some(vector) => {
133+
let v = vector.as_slice();
134+
Ok(Value::from(veclit_to_binlit(v)))
135+
}
136+
}
137+
}
138+
}
139+
140+
#[cfg(test)]
141+
mod tests {
142+
use std::sync::Arc;
143+
144+
use datatypes::vectors::{ConstantVector, StringVector};
145+
146+
use super::*;
147+
148+
#[test]
149+
fn test_update_batch() {
150+
// test update empty batch, expect not updating anything
151+
let mut vec_product = VectorProduct::default();
152+
vec_product.update_batch(&[]).unwrap();
153+
assert!(vec_product.product.is_none());
154+
assert!(!vec_product.has_null);
155+
assert_eq!(Value::Null, vec_product.evaluate().unwrap());
156+
157+
// test update one not-null value
158+
let mut vec_product = VectorProduct::default();
159+
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Some(
160+
"[1.0,2.0,3.0]".to_string(),
161+
)]))];
162+
vec_product.update_batch(&v).unwrap();
163+
assert_eq!(
164+
Value::from(veclit_to_binlit(&[1.0, 2.0, 3.0])),
165+
vec_product.evaluate().unwrap()
166+
);
167+
168+
// test update one null value
169+
let mut vec_product = VectorProduct::default();
170+
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Option::<String>::None]))];
171+
vec_product.update_batch(&v).unwrap();
172+
assert_eq!(Value::Null, vec_product.evaluate().unwrap());
173+
174+
// test update no null-value batch
175+
let mut vec_product = VectorProduct::default();
176+
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
177+
Some("[1.0,2.0,3.0]".to_string()),
178+
Some("[4.0,5.0,6.0]".to_string()),
179+
Some("[7.0,8.0,9.0]".to_string()),
180+
]))];
181+
vec_product.update_batch(&v).unwrap();
182+
assert_eq!(
183+
Value::from(veclit_to_binlit(&[28.0, 80.0, 162.0])),
184+
vec_product.evaluate().unwrap()
185+
);
186+
187+
// test update null-value batch
188+
let mut vec_product = VectorProduct::default();
189+
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
190+
Some("[1.0,2.0,3.0]".to_string()),
191+
None,
192+
Some("[7.0,8.0,9.0]".to_string()),
193+
]))];
194+
vec_product.update_batch(&v).unwrap();
195+
assert_eq!(Value::Null, vec_product.evaluate().unwrap());
196+
197+
// test update with constant vector
198+
let mut vec_product = VectorProduct::default();
199+
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
200+
Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
201+
4,
202+
))];
203+
204+
vec_product.update_batch(&v).unwrap();
205+
206+
assert_eq!(
207+
Value::from(veclit_to_binlit(&[4.0, 8.0, 12.0])),
208+
vec_product.evaluate().unwrap()
209+
);
210+
}
211+
}

0 commit comments

Comments
 (0)