Skip to content

Commit

Permalink
add l2_norm for VectorBorrowed.
Browse files Browse the repository at this point in the history
  • Loading branch information
my-vegetable-has-exploded committed Apr 19, 2024
1 parent c28e1d1 commit da65beb
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
3 changes: 3 additions & 0 deletions crates/base/src/vector/vecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ impl<'a> Vecf32Borrowed<'a> {
pub fn slice(&self) -> &[F32] {
self.0
}
pub fn l2_norm(&self) -> F32 {
dot(self.slice(), self.slice()).sqrt()
}
}

impl<'a> VectorBorrowed for Vecf32Borrowed<'a> {
Expand Down
14 changes: 8 additions & 6 deletions src/datatype/functions_vecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#![allow(clippy::extra_unused_lifetimes)]
use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output};
use crate::error::*;
use base::operator::{Operator, Vecf32Dot};
use base::scalar::*;
use base::vector::*;
use pgrx::pg_sys::Datum;
Expand Down Expand Up @@ -147,20 +146,23 @@ impl IntoDatum for AccumulateState<'_> {
fn type_oid() -> Oid {
let namespace = pgrx::pg_catalog::PgNamespace::search_namespacename(c"vectors").unwrap();
let namespace = namespace.get().expect("pgvecto.rs is not installed.");
let t = pgrx::pg_catalog::PgType::search_typenamensp(c"accumulate_state", namespace.oid())
.unwrap();
let t = pgrx::pg_catalog::PgType::search_typenamensp(
c"vector_accumulate_state ",
namespace.oid(),
)
.unwrap();
let t = t.get().expect("pg_catalog is broken.");
t.oid()
}
}

unsafe impl SqlTranslatable for AccumulateState<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("accumulate_state")))
Ok(SqlMapping::As(String::from("vector_accumulate_state ")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from(
"accumulate_state",
"vector_accumulate_state ",
))))
}
}
Expand Down Expand Up @@ -286,5 +288,5 @@ fn _vectors_vector_dims(vector: Vecf32Input<'_>) -> i32 {
/// Calculate the l2 norm of a vector.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_norm(vector: Vecf32Input<'_>) -> f32 {
(-Vecf32Dot::distance(vector.for_borrow(), vector.for_borrow()).to_f32()).sqrt()
vector.for_borrow().l2_norm().to_f32()
}
2 changes: 1 addition & 1 deletion src/sql/bootstrap.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ CREATE TYPE svector;
CREATE TYPE bvector;
CREATE TYPE veci8;
CREATE TYPE vector_index_stat;
CREATE TYPE accumulate_state;
CREATE TYPE vector_accumulate_state;

-- bootstrap end
4 changes: 2 additions & 2 deletions src/sql/finalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ CREATE TYPE vector_index_stat AS (
idx_options TEXT
);

CREATE TYPE accumulate_state (
CREATE TYPE vector_accumulate_state (
INPUT = _vectors_accumulate_state_in,
OUTPUT = _vectors_accumulate_state_out,
STORAGE = EXTERNAL,
Expand Down Expand Up @@ -611,7 +611,7 @@ STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_vector_norm_wrap

CREATE AGGREGATE avg(vector) (
SFUNC = _vectors_vector_accum,
STYPE = accumulate_state,
STYPE = vector_accumulate_state,
COMBINEFUNC = _vectors_vector_combine,
FINALFUNC = _vectors_vector_final,
INITCOND = '0, []',
Expand Down

0 comments on commit da65beb

Please sign in to comment.