diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index f829ab548..8949328c8 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -67,6 +67,8 @@ jobs: ~/.cargo/registry/cache/ ~/.cargo/git/db/ key: ${{ github.job }}-${{ matrix.version }}-${{ matrix.os }}-${{ hashFiles('./Cargo.lock') }} + - name: Set up PostgreSQL + run: ./scripts/ci_setup.sh - name: Set up Python uses: actions/setup-python@v5 with: @@ -75,9 +77,10 @@ jobs: run: curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash - name: Set up Sqllogictest run: cargo binstall sqllogictest-bin -y --force - - name: Set up Environment - shell: bash - run: ./scripts/ci_setup.sh + - name: Set up Pgrx + run: | + cargo install cargo-pgrx@$(grep 'pgrx = {' Cargo.toml | cut -d '"' -f 2 | head -n 1) --debug + cargo pgrx init --pg$VERSION=$(which pg_config) - name: Release build run: | cargo pgrx install --no-default-features --features "pg$VERSION" --release @@ -125,9 +128,14 @@ jobs: ~/.cargo/registry/cache/ ~/.cargo/git/db/ key: ${{ github.job }}-${{ matrix.version }}-${{ matrix.os }}-${{ hashFiles('./Cargo.lock') }} - - name: Set up Environment - shell: bash + - name: Set up PostgreSQL run: ./scripts/ci_setup.sh + - name: Set up Binstall + run: curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash + - name: Set up Pgrx + run: | + cargo install cargo-pgrx@$(grep 'pgrx = {' Cargo.toml | cut -d '"' -f 2 | head -n 1) --debug + cargo pgrx init --pg$VERSION=$(which pg_config) - name: Format check run: cargo fmt --check - name: Semantic check diff --git a/Cargo.lock b/Cargo.lock index fa2590971..707e79c3c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,9 +82,9 @@ checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" [[package]] name = "arc-swap" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" +checksum = "7b3d0060af21e8d11a926981cc00c6c1541aa91dd64b9f881985c3da1094425f" [[package]] name = "arrayvec" @@ -366,7 +366,7 @@ dependencies = [ "bytemuck", "c", "detect", - "half 2.3.1", + "half 2.4.0", "libc", "multiversion", "num-traits", @@ -531,7 +531,7 @@ version = "0.0.0" dependencies = [ "cc", "detect", - "half 2.3.1", + "half 2.4.0", "rand", ] @@ -569,9 +569,9 @@ dependencies = [ [[package]] name = "cargo_toml" -version = "0.19.1" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dc9f7a067415ab5058020f04c60ec7b557084dbec0e021217bbabc7a8d38d14" +checksum = "a98356df42a2eb1bd8f1793ae4ee4de48e384dd974ce5eac8eee802edb7492be" dependencies = [ "serde", "toml", @@ -579,9 +579,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02f341c093d19155a6e41631ce5971aac4e9a868262212153124c15fa22d1cdc" +checksum = "a0ba8f7aaa012f30d5b2861462f6708eccd49c3c39863fe083a308035f63d723" [[package]] name = "cexpr" @@ -794,6 +794,41 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b365fabc795046672053e29c954733ec3b05e4be654ab130fe8f1f94d7051f35" +[[package]] +name = "darling" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.52", +] + +[[package]] +name = "darling_macro" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.52", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -1226,9 +1261,9 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" [[package]] name = "half" -version = "2.3.1" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +checksum = "b5eceaaeec696539ddaf7b333340f1af35a5aa87ae3e4f3ead0532f72affab2e" dependencies = [ "bytemuck", "cfg-if", @@ -1296,9 +1331,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" dependencies = [ "bytes", "fnv", @@ -1401,14 +1436,10 @@ dependencies = [ ] [[package]] -name = "idna" -version = "0.4.0" +name = "ident_case" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" @@ -1420,12 +1451,6 @@ dependencies = [ "unicode-normalization", ] -[[package]] -name = "if_chain" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" - [[package]] name = "indenter" version = "0.3.3" @@ -1502,9 +1527,9 @@ checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" dependencies = [ "wasm-bindgen", ] @@ -1575,12 +1600,12 @@ checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" +checksum = "2caa5afb8bf9f3a2652760ce7d4f62d21c4d5a423e68466fca30df82f2330164" dependencies = [ "cfg-if", - "windows-sys 0.48.0", + "windows-targets 0.52.4", ] [[package]] @@ -1697,9 +1722,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" dependencies = [ "libc", "wasi", @@ -1847,9 +1872,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" -version = "2.7.7" +version = "2.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219c0dcc30b6a27553f9cc242972b67f75b60eb0db71f0b5462f38b058c41546" +checksum = "56f8023d0fb78c8e03784ea1c7f3fa36e68a723138990b8d5a47d916b651e7a8" dependencies = [ "memchr", "thiserror", @@ -2297,9 +2322,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", @@ -2736,7 +2761,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "std_detect" version = "0.1.5" -source = "git+https://github.com/tensorchord/stdarch.git?branch=avx512fp16#db0cdbc9b02074bfddabfd23a4a681f21640eada" +source = "git+https://github.com/tensorchord/stdarch.git?branch=2024-03-04#6d0479dd6afc14650c1a58af9d881d3d837cc3bd" dependencies = [ "cfg-if", "libc", @@ -2766,6 +2791,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "subtle" version = "2.5.0" @@ -3138,7 +3169,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", - "idna 0.5.0", + "idna", "percent-encoding", ] @@ -3160,12 +3191,12 @@ dependencies = [ [[package]] name = "validator" -version = "0.16.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b92f40481c04ff1f4f61f304d61793c7b56ff76ac1469f1beb199b1445b253bd" +checksum = "da339118f018cc70ebf01fafc103360528aad53717e4bf311db929cb01cb9345" dependencies = [ - "idna 0.4.0", - "lazy_static", + "idna", + "once_cell", "regex", "serde", "serde_derive", @@ -3176,28 +3207,16 @@ dependencies = [ [[package]] name = "validator_derive" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc44ca3088bb3ba384d9aecf40c6a23a676ce23e09bdaca2073d99c207f864af" +checksum = "76e88ea23b8f5e59230bff8a2f03c0ee0054a61d5b8343a38946bcd406fe624c" dependencies = [ - "if_chain", - "lazy_static", + "darling", "proc-macro-error", "proc-macro2", "quote", "regex", - "syn 1.0.109", - "validator_types", -] - -[[package]] -name = "validator_types" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "111abfe30072511849c5910134e8baf8dc05de4c0e5903d681cbd5c9c4d611e3" -dependencies = [ - "proc-macro2", - "syn 1.0.109", + "syn 2.0.52", ] [[package]] @@ -3262,9 +3281,9 @@ checksum = "f3c4517f54858c779bbcbf228f4fca63d121bf85fbecb2dc578cdf4a39395690" [[package]] name = "walkdir" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" dependencies = [ "same-file", "winapi-util", @@ -3285,11 +3304,17 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -3297,9 +3322,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" dependencies = [ "bumpalo", "log", @@ -3312,9 +3337,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.41" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877b9c3f61ceea0e56331985743b13f3d25c406a7098d45180fb5f09bc19ed97" +checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" dependencies = [ "cfg-if", "js-sys", @@ -3324,9 +3349,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3334,9 +3359,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", @@ -3347,15 +3372,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" [[package]] name = "web-sys" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" dependencies = [ "js-sys", "wasm-bindgen", @@ -3369,11 +3394,12 @@ checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "whoami" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" +checksum = "0fec781d48b41f8163426ed18e8fc2864c12937df9ce54c88ede7bd47270893e" dependencies = [ - "wasm-bindgen", + "redox_syscall", + "wasite", "web-sys", ] diff --git a/Cargo.toml b/Cargo.toml index e00aeb519..c95f855b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,7 +70,7 @@ arrayvec = "~0.7" bincode = "~1.3" bytemuck = { version = "~1.14", features = ["extern_crate_alloc"] } byteorder = "~1.5" -half = { version = "~2.3", features = [ +half = { version = "~2.4", features = [ "bytemuck", "num-traits", "serde", @@ -88,7 +88,7 @@ serde = "~1.0" serde_json = "~1.0" thiserror = "~1.0" uuid = { version = "1.7.0", features = ["v4", "serde"] } -validator = { version = "~0.16", features = ["derive"] } +validator = { version = "~0.17", features = ["derive"] } [workspace.lints] rust.unsafe_op_in_unsafe_fn = "forbid" diff --git a/crates/base/src/global/bvecf32.rs b/crates/base/src/global/bvecf32.rs index a145f93c0..0279d6ee3 100644 --- a/crates/base/src/global/bvecf32.rs +++ b/crates/base/src/global/bvecf32.rs @@ -31,10 +31,26 @@ pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { } #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "avx512vpopcntdq")] - unsafe fn cosine_avx(lhs: &[usize], rhs: &[usize]) -> F32 { + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + unsafe fn cosine_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { + use std::arch::x86_64::*; + #[inline] + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { + let mut dst: __m512i; + unsafe { + std::arch::asm!( + "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", + p = in(reg) mem_addr, + k = in(kreg) k, + dst = out(zmm_reg) dst, + options(pure, readonly, nostack) + ); + } + dst + } + assert_eq!(lhs.len(), rhs.len()); unsafe { - use std::arch::x86_64::*; const WIDTH: usize = 512 / 8 / std::mem::size_of::(); let mut xy = _mm512_setzero_si512(); let mut xx = _mm512_setzero_si512(); @@ -53,7 +69,7 @@ pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { yy = _mm512_add_epi64(yy, _mm512_popcnt_epi64(y)); } if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32).try_into().unwrap(); + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; let x = _mm512_maskz_loadu_epi64(mask, a.cast()); let y = _mm512_maskz_loadu_epi64(mask, b.cast()); xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); @@ -70,7 +86,7 @@ pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[cfg(target_arch = "x86_64")] if detect::x86_64::detect_avx512vpopcntdq() { unsafe { - return cosine_avx(lhs, rhs); + return cosine_avx512vpopcntdq(lhs, rhs); } } cosine(lhs, rhs) @@ -98,10 +114,26 @@ pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { } #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "avx512vpopcntdq")] - unsafe fn dot_avx(lhs: &[usize], rhs: &[usize]) -> F32 { + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + unsafe fn dot_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { + use std::arch::x86_64::*; + #[inline] + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { + let mut dst: __m512i; + unsafe { + std::arch::asm!( + "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", + p = in(reg) mem_addr, + k = in(kreg) k, + dst = out(zmm_reg) dst, + options(pure, readonly, nostack) + ); + } + dst + } + assert_eq!(lhs.len(), rhs.len()); unsafe { - use std::arch::x86_64::*; const WIDTH: usize = 512 / 8 / std::mem::size_of::(); let mut xy = _mm512_setzero_si512(); let mut a = lhs.as_ptr(); @@ -116,7 +148,7 @@ pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); } if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32).try_into().unwrap(); + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; let x = _mm512_maskz_loadu_epi64(mask, a.cast()); let y = _mm512_maskz_loadu_epi64(mask, b.cast()); xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); @@ -129,7 +161,7 @@ pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[cfg(target_arch = "x86_64")] if detect::x86_64::detect_avx512vpopcntdq() { unsafe { - return dot_avx(lhs, rhs); + return dot_avx512vpopcntdq(lhs, rhs); } } dot(lhs, rhs) @@ -157,10 +189,26 @@ pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { } #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "avx512vpopcntdq")] - unsafe fn sl2_avx(lhs: &[usize], rhs: &[usize]) -> F32 { + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + unsafe fn sl2_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { + use std::arch::x86_64::*; + #[inline] + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { + let mut dst: __m512i; + unsafe { + std::arch::asm!( + "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", + p = in(reg) mem_addr, + k = in(kreg) k, + dst = out(zmm_reg) dst, + options(pure, readonly, nostack) + ); + } + dst + } + assert_eq!(lhs.len(), rhs.len()); unsafe { - use std::arch::x86_64::*; const WIDTH: usize = 512 / 8 / std::mem::size_of::(); let mut dd = _mm512_setzero_si512(); let mut a = lhs.as_ptr(); @@ -175,7 +223,7 @@ pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y))); } if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32).try_into().unwrap(); + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; let x = _mm512_maskz_loadu_epi64(mask, a.cast()); let y = _mm512_maskz_loadu_epi64(mask, b.cast()); dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y))); @@ -188,7 +236,7 @@ pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[cfg(target_arch = "x86_64")] if detect::x86_64::detect_avx512vpopcntdq() { unsafe { - return sl2_avx(lhs, rhs); + return sl2_avx512vpopcntdq(lhs, rhs); } } sl2(lhs, rhs) @@ -218,10 +266,26 @@ pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { } #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "avx512vpopcntdq")] - unsafe fn jaccard_avx(lhs: &[usize], rhs: &[usize]) -> F32 { + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + unsafe fn jaccard_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { + use std::arch::x86_64::*; + #[inline] + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { + let mut dst: __m512i; + unsafe { + std::arch::asm!( + "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", + p = in(reg) mem_addr, + k = in(kreg) k, + dst = out(zmm_reg) dst, + options(pure, readonly, nostack) + ); + } + dst + } + assert_eq!(lhs.len(), rhs.len()); unsafe { - use std::arch::x86_64::*; const WIDTH: usize = 512 / 8 / std::mem::size_of::(); let mut inter = _mm512_setzero_si512(); let mut union = _mm512_setzero_si512(); @@ -238,7 +302,7 @@ pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { union = _mm512_add_epi64(union, _mm512_popcnt_epi64(_mm512_or_si512(x, y))); } if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32).try_into().unwrap(); + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; let x = _mm512_maskz_loadu_epi64(mask, a.cast()); let y = _mm512_maskz_loadu_epi64(mask, b.cast()); inter = _mm512_add_epi64(inter, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); @@ -253,7 +317,7 @@ pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[cfg(target_arch = "x86_64")] if detect::x86_64::detect_avx512vpopcntdq() { unsafe { - return jaccard_avx(lhs, rhs); + return jaccard_avx512vpopcntdq(lhs, rhs); } } jaccard(lhs, rhs) @@ -279,10 +343,25 @@ pub fn length(vector: BVecf32Borrowed<'_>) -> F32 { } #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "avx512vpopcntdq")] - unsafe fn length_avx(lhs: &[usize]) -> F32 { + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + unsafe fn length_avx512vpopcntdq(lhs: &[usize]) -> F32 { + use std::arch::x86_64::*; + #[inline] + #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { + let mut dst: __m512i; + unsafe { + std::arch::asm!( + "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", + p = in(reg) mem_addr, + k = in(kreg) k, + dst = out(zmm_reg) dst, + options(pure, readonly, nostack) + ); + } + dst + } unsafe { - use std::arch::x86_64::*; const WIDTH: usize = 512 / 8 / std::mem::size_of::(); let mut cnt = _mm512_setzero_si512(); let mut a = lhs.as_ptr(); @@ -294,7 +373,7 @@ pub fn length(vector: BVecf32Borrowed<'_>) -> F32 { cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x)); } if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32).try_into().unwrap(); + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; let x = _mm512_maskz_loadu_epi64(mask, a.cast()); cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x)); } @@ -306,7 +385,7 @@ pub fn length(vector: BVecf32Borrowed<'_>) -> F32 { #[cfg(target_arch = "x86_64")] if detect::x86_64::detect_avx512vpopcntdq() { unsafe { - return length_avx(vector); + return length_avx512vpopcntdq(vector); } } length(vector) diff --git a/crates/base/src/global/mod.rs b/crates/base/src/global/mod.rs index 56993386e..6a1584f81 100644 --- a/crates/base/src/global/mod.rs +++ b/crates/base/src/global/mod.rs @@ -42,7 +42,7 @@ use crate::scalar::*; use crate::vector::*; pub trait GlobalElkanKMeans: Global { - type VectorNormalized: VectorOwned = Self::VectorOwned; + type VectorNormalized: VectorOwned; fn elkan_k_means_normalize(vector: &mut [Scalar]); fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Self::VectorNormalized; diff --git a/crates/base/src/global/svecf32.rs b/crates/base/src/global/svecf32.rs index cd45491a1..02095e417 100644 --- a/crates/base/src/global/svecf32.rs +++ b/crates/base/src/global/svecf32.rs @@ -1,4 +1,3 @@ -use super::SVecf32Owned; use crate::scalar::*; use crate::vector::*; use num_traits::{Float, Zero}; diff --git a/crates/base/src/global/svecf32_cos.rs b/crates/base/src/global/svecf32_cos.rs index 6906bad94..0a7506ae3 100644 --- a/crates/base/src/global/svecf32_cos.rs +++ b/crates/base/src/global/svecf32_cos.rs @@ -19,6 +19,8 @@ impl Global for SVecf32Cos { } impl GlobalElkanKMeans for SVecf32Cos { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(vector: &mut [Scalar]) { super::vecf32::l2_normalize(vector) } diff --git a/crates/base/src/global/svecf32_dot.rs b/crates/base/src/global/svecf32_dot.rs index ad6c16271..8cdf80c38 100644 --- a/crates/base/src/global/svecf32_dot.rs +++ b/crates/base/src/global/svecf32_dot.rs @@ -19,6 +19,8 @@ impl Global for SVecf32Dot { } impl GlobalElkanKMeans for SVecf32Dot { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(vector: &mut [Scalar]) { super::vecf32::l2_normalize(vector) } diff --git a/crates/base/src/global/svecf32_l2.rs b/crates/base/src/global/svecf32_l2.rs index 85d428560..805c207d5 100644 --- a/crates/base/src/global/svecf32_l2.rs +++ b/crates/base/src/global/svecf32_l2.rs @@ -19,6 +19,8 @@ impl Global for SVecf32L2 { } impl GlobalElkanKMeans for SVecf32L2 { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(_: &mut [Scalar]) {} fn elkan_k_means_normalize2(vector: SVecf32Borrowed<'_>) -> SVecf32Owned { diff --git a/crates/base/src/global/vecf16_cos.rs b/crates/base/src/global/vecf16_cos.rs index 2915e0858..1e51f7cd7 100644 --- a/crates/base/src/global/vecf16_cos.rs +++ b/crates/base/src/global/vecf16_cos.rs @@ -19,6 +19,8 @@ impl Global for Vecf16Cos { } impl GlobalElkanKMeans for Vecf16Cos { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(vector: &mut [F16]) { super::vecf16::l2_normalize(vector) } diff --git a/crates/base/src/global/vecf16_dot.rs b/crates/base/src/global/vecf16_dot.rs index f45630ae7..aa6c24518 100644 --- a/crates/base/src/global/vecf16_dot.rs +++ b/crates/base/src/global/vecf16_dot.rs @@ -19,6 +19,8 @@ impl Global for Vecf16Dot { } impl GlobalElkanKMeans for Vecf16Dot { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(vector: &mut [F16]) { super::vecf16::l2_normalize(vector) } diff --git a/crates/base/src/global/vecf16_l2.rs b/crates/base/src/global/vecf16_l2.rs index a40b6af71..03aa8d1e0 100644 --- a/crates/base/src/global/vecf16_l2.rs +++ b/crates/base/src/global/vecf16_l2.rs @@ -19,6 +19,8 @@ impl Global for Vecf16L2 { } impl GlobalElkanKMeans for Vecf16L2 { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(_: &mut [F16]) {} fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned { diff --git a/crates/base/src/global/vecf32_cos.rs b/crates/base/src/global/vecf32_cos.rs index 2086aa6a8..8a49037be 100644 --- a/crates/base/src/global/vecf32_cos.rs +++ b/crates/base/src/global/vecf32_cos.rs @@ -19,6 +19,8 @@ impl Global for Vecf32Cos { } impl GlobalElkanKMeans for Vecf32Cos { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(vector: &mut [F32]) { super::vecf32::l2_normalize(vector) } diff --git a/crates/base/src/global/vecf32_dot.rs b/crates/base/src/global/vecf32_dot.rs index b62a853d2..0a7ad6ccc 100644 --- a/crates/base/src/global/vecf32_dot.rs +++ b/crates/base/src/global/vecf32_dot.rs @@ -19,6 +19,8 @@ impl Global for Vecf32Dot { } impl GlobalElkanKMeans for Vecf32Dot { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(vector: &mut [F32]) { super::vecf32::l2_normalize(vector) } diff --git a/crates/base/src/global/vecf32_l2.rs b/crates/base/src/global/vecf32_l2.rs index bfec14d9e..7d994434d 100644 --- a/crates/base/src/global/vecf32_l2.rs +++ b/crates/base/src/global/vecf32_l2.rs @@ -19,6 +19,8 @@ impl Global for Vecf32L2 { } impl GlobalElkanKMeans for Vecf32L2 { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(_: &mut [F32]) {} fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned { diff --git a/crates/base/src/global/veci8.rs b/crates/base/src/global/veci8.rs index a2622fb39..e5e3c00e6 100644 --- a/crates/base/src/global/veci8.rs +++ b/crates/base/src/global/veci8.rs @@ -31,10 +31,24 @@ fn dot_i8_fallback(x: &[I8], y: &[I8]) -> F32 { } #[cfg(target_arch = "x86_64")] -#[target_feature(enable = "avx512f,avx512bw,avx512vnni,bmi2")] +#[target_feature(enable = "avx512vnni,avx512bw,avx512f,bmi2")] unsafe fn dot_i8_avx512vnni(x: &[I8], y: &[I8]) -> F32 { use std::arch::x86_64::*; - + #[inline] + #[target_feature(enable = "avx512vnni,avx512bw,avx512f,bmi2")] + pub unsafe fn _mm512_maskz_loadu_epi8(k: __mmask64, mem_addr: *const i8) -> __m512i { + let mut dst: __m512i; + unsafe { + std::arch::asm!( + "vmovdqu8 {dst}{{{k}}} {{z}}, [{p}]", + p = in(reg) mem_addr, + k = in(kreg) k, + dst = out(zmm_reg) dst, + options(pure, readonly, nostack) + ); + } + dst + } assert_eq!(x.len(), y.len()); let mut sum = 0; let mut i = x.len(); @@ -186,7 +200,7 @@ mod tests { let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); let result = cosine_distance(&ref_x, &ref_y); - assert!((result.0 - result_expected).abs() / result_expected < 0.05); + assert!((result.0 - result_expected).abs() / result_expected < 0.25); } #[test] diff --git a/crates/base/src/global/veci8_cos.rs b/crates/base/src/global/veci8_cos.rs index f27a24dfb..9bf6dcb8d 100644 --- a/crates/base/src/global/veci8_cos.rs +++ b/crates/base/src/global/veci8_cos.rs @@ -19,6 +19,8 @@ impl Global for Veci8Cos { } impl GlobalElkanKMeans for Veci8Cos { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(vector: &mut [Scalar]) { super::vecf32::l2_normalize(vector) } diff --git a/crates/base/src/global/veci8_dot.rs b/crates/base/src/global/veci8_dot.rs index a8fba8892..40d0b7749 100644 --- a/crates/base/src/global/veci8_dot.rs +++ b/crates/base/src/global/veci8_dot.rs @@ -19,6 +19,8 @@ impl Global for Veci8Dot { } impl GlobalElkanKMeans for Veci8Dot { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(vector: &mut [Scalar]) { super::vecf32::l2_normalize(vector) } diff --git a/crates/base/src/global/veci8_l2.rs b/crates/base/src/global/veci8_l2.rs index 966be1ba6..c2bca65e4 100644 --- a/crates/base/src/global/veci8_l2.rs +++ b/crates/base/src/global/veci8_l2.rs @@ -19,6 +19,8 @@ impl Global for Veci8L2 { } impl GlobalElkanKMeans for Veci8L2 { + type VectorNormalized = Self::VectorOwned; + fn elkan_k_means_normalize(vector: &mut [Scalar]) { super::vecf32::l2_normalize(vector) } diff --git a/crates/base/src/index.rs b/crates/base/src/index.rs index c2f3ffdd3..d00ec89dd 100644 --- a/crates/base/src/index.rs +++ b/crates/base/src/index.rs @@ -78,10 +78,10 @@ impl VectorOptions { #[validate(schema(function = "Self::validate_0"))] pub struct SegmentsOptions { #[serde(default = "SegmentsOptions::default_max_growing_segment_size")] - #[validate(range(min = 1, max = 4_000_000_000))] + #[validate(range(min = 1, max = 4_000_000_000u32))] pub max_growing_segment_size: u32, #[serde(default = "SegmentsOptions::default_max_sealed_segment_size")] - #[validate(range(min = 1, max = 4_000_000_000))] + #[validate(range(min = 1, max = 4_000_000_000u32))] pub max_sealed_segment_size: u32, } @@ -119,7 +119,7 @@ pub struct OptimizingOptions { #[validate(range(min = 1, max = 60))] pub sealing_secs: u64, #[serde(default = "OptimizingOptions::default_sealing_size")] - #[validate(range(min = 1, max = 4_000_000_000))] + #[validate(range(min = 1, max = 4_000_000_000u32))] pub sealing_size: u32, #[serde(default = "OptimizingOptions::default_delete_threshold")] #[validate(range(min = 0.01, max = 1.00))] diff --git a/crates/base/src/lib.rs b/crates/base/src/lib.rs index 411e2dec3..a971ab694 100644 --- a/crates/base/src/lib.rs +++ b/crates/base/src/lib.rs @@ -1,7 +1,6 @@ #![feature(core_intrinsics)] #![feature(avx512_target_feature)] -#![feature(associated_type_defaults)] -#![feature(stdsimd)] +#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] #![allow(internal_features)] #![allow(clippy::derivable_impls)] #![allow(clippy::len_without_is_empty)] diff --git a/crates/detect/Cargo.toml b/crates/detect/Cargo.toml index e2cd91c15..f1f411451 100644 --- a/crates/detect/Cargo.toml +++ b/crates/detect/Cargo.toml @@ -5,7 +5,7 @@ edition.workspace = true [dependencies] rustix.workspace = true -std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" } +std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "2024-03-04" } [lints] workspace = true diff --git a/crates/detect/src/x86_64.rs b/crates/detect/src/x86_64.rs index f1367a171..103718940 100644 --- a/crates/detect/src/x86_64.rs +++ b/crates/detect/src/x86_64.rs @@ -8,7 +8,7 @@ pub fn test_avx512fp16() -> bool { } pub fn test_avx512vpopcntdq() -> bool { - std_detect::is_x86_feature_detected!("avx512vpopcntdq") && test_v4() + std::is_x86_feature_detected!("avx512vpopcntdq") && test_v4() } pub fn ctor_avx512fp16() { @@ -30,11 +30,11 @@ pub fn detect_avx512vpopcntdq() -> bool { static ATOMIC_V4: AtomicBool = AtomicBool::new(false); pub fn test_v4() -> bool { - std_detect::is_x86_feature_detected!("avx512bw") - && std_detect::is_x86_feature_detected!("avx512cd") - && std_detect::is_x86_feature_detected!("avx512dq") - && std_detect::is_x86_feature_detected!("avx512f") - && std_detect::is_x86_feature_detected!("avx512vl") + std::is_x86_feature_detected!("avx512bw") + && std::is_x86_feature_detected!("avx512cd") + && std::is_x86_feature_detected!("avx512dq") + && std::is_x86_feature_detected!("avx512f") + && std::is_x86_feature_detected!("avx512vl") && test_v3() } @@ -49,15 +49,15 @@ pub fn detect_v4() -> bool { static ATOMIC_V3: AtomicBool = AtomicBool::new(false); pub fn test_v3() -> bool { - std_detect::is_x86_feature_detected!("avx") - && std_detect::is_x86_feature_detected!("avx2") - && std_detect::is_x86_feature_detected!("bmi1") - && std_detect::is_x86_feature_detected!("bmi2") - && std_detect::is_x86_feature_detected!("f16c") - && std_detect::is_x86_feature_detected!("fma") - && std_detect::is_x86_feature_detected!("lzcnt") - && std_detect::is_x86_feature_detected!("movbe") - && std_detect::is_x86_feature_detected!("xsave") + std::is_x86_feature_detected!("avx") + && std::is_x86_feature_detected!("avx2") + && std::is_x86_feature_detected!("bmi1") + && std::is_x86_feature_detected!("bmi2") + && std::is_x86_feature_detected!("f16c") + && std::is_x86_feature_detected!("fma") + && std::is_x86_feature_detected!("lzcnt") + && std::is_x86_feature_detected!("movbe") + && std::is_x86_feature_detected!("xsave") && test_v2() } @@ -72,15 +72,15 @@ pub fn detect_v3() -> bool { static ATOMIC_V2: AtomicBool = AtomicBool::new(false); pub fn test_v2() -> bool { - std_detect::is_x86_feature_detected!("cmpxchg16b") - && std_detect::is_x86_feature_detected!("fxsr") - && std_detect::is_x86_feature_detected!("popcnt") - && std_detect::is_x86_feature_detected!("sse") - && std_detect::is_x86_feature_detected!("sse2") - && std_detect::is_x86_feature_detected!("sse3") - && std_detect::is_x86_feature_detected!("sse4.1") - && std_detect::is_x86_feature_detected!("sse4.2") - && std_detect::is_x86_feature_detected!("ssse3") + std::is_x86_feature_detected!("cmpxchg16b") + && std::is_x86_feature_detected!("fxsr") + && std::is_x86_feature_detected!("popcnt") + && std::is_x86_feature_detected!("sse") + && std::is_x86_feature_detected!("sse2") + && std::is_x86_feature_detected!("sse3") + && std::is_x86_feature_detected!("sse4.1") + && std::is_x86_feature_detected!("sse4.2") + && std::is_x86_feature_detected!("ssse3") } pub fn ctor_v2() { @@ -95,7 +95,7 @@ static ATOMIC_AVX512VNNI: AtomicBool = AtomicBool::new(false); /// check if the CPU supports avx512vnni pub fn test_avx512vnni() -> bool { - std_detect::is_x86_feature_detected!("avx512vnni") && test_v4() + std::is_x86_feature_detected!("avx512vnni") && test_v4() } pub fn ctor_vnni() { diff --git a/crates/service/src/algorithms/quantization/product.rs b/crates/service/src/algorithms/quantization/product.rs index 37cd6d4b1..b27fae135 100644 --- a/crates/service/src/algorithms/quantization/product.rs +++ b/crates/service/src/algorithms/quantization/product.rs @@ -1,6 +1,5 @@ use crate::algorithms::clustering::elkan_k_means::ElkanKMeans; use crate::algorithms::quantization::Quan; -use crate::algorithms::quantization::QuantizationOptions; use crate::algorithms::raw::Raw; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; diff --git a/crates/service/src/algorithms/quantization/scalar.rs b/crates/service/src/algorithms/quantization/scalar.rs index 033ea9df3..6cb40b489 100644 --- a/crates/service/src/algorithms/quantization/scalar.rs +++ b/crates/service/src/algorithms/quantization/scalar.rs @@ -1,5 +1,4 @@ use crate::algorithms::quantization::Quan; -use crate::algorithms::quantization::QuantizationOptions; use crate::algorithms::raw::Raw; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; diff --git a/crates/service/src/algorithms/quantization/trivial.rs b/crates/service/src/algorithms/quantization/trivial.rs index 2350fa62e..abac18002 100644 --- a/crates/service/src/algorithms/quantization/trivial.rs +++ b/crates/service/src/algorithms/quantization/trivial.rs @@ -1,5 +1,4 @@ use crate::algorithms::quantization::Quan; -use crate::algorithms::quantization::QuantizationOptions; use crate::algorithms::raw::Raw; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; diff --git a/crates/service/src/index/indexing/flat.rs b/crates/service/src/index/indexing/flat.rs index f6660a21a..354ab4a33 100644 --- a/crates/service/src/index/indexing/flat.rs +++ b/crates/service/src/index/indexing/flat.rs @@ -1,7 +1,5 @@ use super::AbstractIndexing; use crate::index::segments::growing::GrowingSegment; -use crate::index::IndexOptions; -use crate::index::SearchOptions; use crate::prelude::*; use crate::{algorithms::flat::Flat, index::segments::sealed::SealedSegment}; use std::cmp::Reverse; diff --git a/crates/service/src/index/indexing/hnsw.rs b/crates/service/src/index/indexing/hnsw.rs index a99710bde..9c8bb482c 100644 --- a/crates/service/src/index/indexing/hnsw.rs +++ b/crates/service/src/index/indexing/hnsw.rs @@ -2,8 +2,6 @@ use super::AbstractIndexing; use crate::algorithms::hnsw::Hnsw; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::IndexOptions; -use crate::index::SearchOptions; use crate::prelude::*; use std::cmp::Reverse; use std::collections::BinaryHeap; diff --git a/crates/service/src/index/indexing/ivf.rs b/crates/service/src/index/indexing/ivf.rs index 85c06c7d5..94a3ed362 100644 --- a/crates/service/src/index/indexing/ivf.rs +++ b/crates/service/src/index/indexing/ivf.rs @@ -2,8 +2,6 @@ use super::AbstractIndexing; use crate::algorithms::ivf::Ivf; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::IndexOptions; -use crate::index::SearchOptions; use crate::prelude::*; use std::cmp::Reverse; use std::collections::BinaryHeap; diff --git a/crates/service/src/index/indexing/mod.rs b/crates/service/src/index/indexing/mod.rs index 111f5a957..b2d625a18 100644 --- a/crates/service/src/index/indexing/mod.rs +++ b/crates/service/src/index/indexing/mod.rs @@ -7,8 +7,6 @@ use self::hnsw::HnswIndexing; use self::ivf::IvfIndexing; use super::segments::growing::GrowingSegment; use super::segments::sealed::SealedSegment; -use super::IndexOptions; -use crate::index::SearchOptions; use crate::prelude::*; use std::cmp::Reverse; use std::collections::BinaryHeap; diff --git a/crates/service/src/index/segments/growing.rs b/crates/service/src/index/segments/growing.rs index 65cf5647c..5a5cbcd88 100644 --- a/crates/service/src/index/segments/growing.rs +++ b/crates/service/src/index/segments/growing.rs @@ -1,8 +1,5 @@ use super::SegmentTracker; -use crate::index::IndexOptions; use crate::index::IndexTracker; -use crate::index::SearchOptions; -use crate::index::SegmentStat; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::file_wal::FileWal; diff --git a/crates/service/src/index/segments/sealed.rs b/crates/service/src/index/segments/sealed.rs index 137a33d71..b68f25eb7 100644 --- a/crates/service/src/index/segments/sealed.rs +++ b/crates/service/src/index/segments/sealed.rs @@ -1,7 +1,7 @@ use super::growing::GrowingSegment; use super::SegmentTracker; use crate::index::indexing::DynamicIndexing; -use crate::index::{IndexOptions, IndexTracker, SearchOptions, SegmentStat}; +use crate::index::IndexTracker; use crate::prelude::*; use crate::utils::dir_ops::{dir_size, sync_dir}; use std::cmp::Reverse; diff --git a/crates/service/src/lib.rs b/crates/service/src/lib.rs index dae0d7b93..0d20b972c 100644 --- a/crates/service/src/lib.rs +++ b/crates/service/src/lib.rs @@ -1,5 +1,5 @@ #![allow(clippy::needless_range_loop)] -#![feature(stdsimd)] +#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] mod algorithms; mod index; diff --git a/crates/service/src/storage/mod.rs b/crates/service/src/storage/mod.rs index d741eea40..ec3406688 100644 --- a/crates/service/src/storage/mod.rs +++ b/crates/service/src/storage/mod.rs @@ -15,6 +15,7 @@ use std::path::Path; pub trait Storage { type VectorOwned: VectorOwned; + #[allow(unused)] fn dims(&self) -> u32; fn len(&self) -> u32; fn vector(&self, i: u32) -> ::Borrowed<'_>; diff --git a/crates/service/src/utils/file_wal.rs b/crates/service/src/utils/file_wal.rs index d487456ab..10f3a9707 100644 --- a/crates/service/src/utils/file_wal.rs +++ b/crates/service/src/utils/file_wal.rs @@ -36,6 +36,7 @@ impl FileWal { .create(true) .write(true) .read(true) + .truncate(false) .open(path) .expect("Failed to open wal."); Self { diff --git a/rust-toolchain.toml b/rust-toolchain.toml index f86a74a98..89c656b5f 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2024-01-14" +channel = "nightly-2024-03-04" profile = "default" targets = [ "aarch64-apple-darwin", diff --git a/scripts/build_2.sh b/scripts/build_2.sh index 6cd014fc9..2c08acf53 100755 --- a/scripts/build_2.sh +++ b/scripts/build_2.sh @@ -8,7 +8,7 @@ printf "ARCH = ${ARCH}\n" export PLATFORM=$(echo $ARCH | sed 's/aarch64/arm64/; s/x86_64/amd64/') cargo build --release --no-default-features --features pg$VERSION --target ${ARCH}-unknown-linux-gnu -cargo pgrx schema --no-default-features --features pg$VERSION > ./target/vectors--$SEMVER.sql +cargo pgrx schema --no-default-features --features pg$VERSION | expand -t 4 > ./target/vectors--$SEMVER.sql rm -rf ./build/dir_zip rm -rf ./build/vectors-pg${VERSION}_${ARCH}-unknown-linux-gnu_${SEMVER}.zip diff --git a/scripts/ci_setup.sh b/scripts/ci_setup.sh index 43e092f81..03d2d3196 100755 --- a/scripts/ci_setup.sh +++ b/scripts/ci_setup.sh @@ -28,6 +28,3 @@ fi sudo chmod -R 777 `pg_config --pkglibdir` sudo chmod -R 777 `pg_config --sharedir`/extension - -cargo install cargo-pgrx@$(grep 'pgrx = {' Cargo.toml | cut -d '"' -f 2 | head -n 1) --debug -cargo pgrx init --pg$VERSION=$(which pg_config) diff --git a/src/bgworker/normal.rs b/src/bgworker/normal.rs index 8d722d139..0caae237e 100644 --- a/src/bgworker/normal.rs +++ b/src/bgworker/normal.rs @@ -1,5 +1,6 @@ use crate::ipc::ConnectionError; -use crate::ipc::ServerRpcHandler; +use crate::ipc::{listen_mmap, listen_unix}; +use crate::ipc::{ServerRpcHandle, ServerRpcHandler}; use service::Worker; use std::convert::Infallible; use std::sync::Arc; @@ -9,7 +10,7 @@ pub fn normal(worker: Arc) { scope.spawn({ let worker = worker.clone(); move || { - for rpc_handler in crate::ipc::listen_unix() { + for rpc_handler in listen_unix() { let worker = worker.clone(); std::thread::spawn({ move || { @@ -24,7 +25,7 @@ pub fn normal(worker: Arc) { scope.spawn({ let worker = worker.clone(); move || { - for rpc_handler in crate::ipc::listen_mmap() { + for rpc_handler in listen_mmap() { let worker = worker.clone(); std::thread::spawn({ move || { @@ -59,7 +60,6 @@ pub fn normal(worker: Arc) { } fn session(worker: Arc, handler: ServerRpcHandler) -> Result { - use crate::ipc::ServerRpcHandle; use base::worker::*; let mut handler = handler; loop { diff --git a/src/datatype/binary.rs b/src/datatype/binary.rs new file mode 100644 index 000000000..66337a897 --- /dev/null +++ b/src/datatype/binary.rs @@ -0,0 +1,36 @@ +use pgrx::datum::IntoDatum; +use pgrx::pg_sys::{bytea, Datum, Oid}; +use pgrx::pgrx_sql_entity_graph::metadata::*; + +#[repr(transparent)] +pub struct Bytea(*mut bytea); + +impl Bytea { + pub fn new(x: *mut bytea) -> Self { + Self(x) + } +} + +impl IntoDatum for Bytea { + fn into_datum(self) -> Option { + if !self.0.is_null() { + Some(pgrx::pg_sys::Datum::from(self.0)) + } else { + None + } + } + + fn type_oid() -> Oid { + pgrx::pg_sys::BYTEAOID + } +} + +unsafe impl SqlTranslatable for Bytea { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("bytea"))) + } + + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("bytea")))) + } +} diff --git a/src/datatype/binary_bvecf32.rs b/src/datatype/binary_bvecf32.rs index 54a1fa4ab..71902d54b 100644 --- a/src/datatype/binary_bvecf32.rs +++ b/src/datatype/binary_bvecf32.rs @@ -1,15 +1,14 @@ +use super::binary::Bytea; use super::memory_bvecf32::BVecf32Input; use super::memory_bvecf32::BVecf32Output; use base::vector::BVecf32Borrowed; use base::vector::BVEC_WIDTH; +use pgrx::datum::Internal; use pgrx::datum::IntoDatum; -use pgrx::pg_sys::Datum; use pgrx::pg_sys::Oid; -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_bvecf32_send(bvector) RETURNS bytea -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_bvecf32_send(vector: BVecf32Input<'_>) -> Datum { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_bvecf32_send(vector: BVecf32Input<'_>) -> Bytea { use pgrx::pg_sys::StringInfoData; unsafe { let mut buf = StringInfoData::default(); @@ -18,14 +17,12 @@ fn _vectors_bvecf32_send(vector: BVecf32Input<'_>) -> Datum { pgrx::pg_sys::pq_begintypsend(&mut buf); pgrx::pg_sys::pq_sendbytes(&mut buf, (&len) as *const u16 as _, 2); pgrx::pg_sys::pq_sendbytes(&mut buf, vector.data().as_ptr() as _, bytes as _); - Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) + Bytea::new(pgrx::pg_sys::pq_endtypsend(&mut buf)) } } -#[pgrx::pg_extern(sql = " -CREATE FUNCTION _vectors_bvecf32_recv(internal, oid, integer) RETURNS bvector -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_bvecf32_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> BVecf32Output { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_bvecf32_recv(internal: Internal, _oid: Oid, _typmod: i32) -> BVecf32Output { use pgrx::pg_sys::StringInfo; unsafe { let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); diff --git a/src/datatype/binary_svecf32.rs b/src/datatype/binary_svecf32.rs index 3b91301aa..7f2936516 100644 --- a/src/datatype/binary_svecf32.rs +++ b/src/datatype/binary_svecf32.rs @@ -1,16 +1,15 @@ +use super::binary::Bytea; use super::memory_svecf32::SVecf32Input; use super::memory_svecf32::SVecf32Output; use base::scalar::F32; use base::vector::SVecf32Borrowed; +use pgrx::datum::Internal; use pgrx::datum::IntoDatum; -use pgrx::pg_sys::Datum; use pgrx::pg_sys::Oid; use std::ffi::c_char; -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_svecf32_send(svector) RETURNS bytea -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_svecf32_send(vector: SVecf32Input<'_>) -> Datum { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_svecf32_send(vector: SVecf32Input<'_>) -> Bytea { use pgrx::pg_sys::StringInfoData; unsafe { let mut buf = StringInfoData::default(); @@ -24,14 +23,12 @@ fn _vectors_svecf32_send(vector: SVecf32Input<'_>) -> Datum { pgrx::pg_sys::pq_sendbytes(&mut buf, (&len) as *const u32 as _, 4); pgrx::pg_sys::pq_sendbytes(&mut buf, x.indexes().as_ptr() as _, b_indexes as _); pgrx::pg_sys::pq_sendbytes(&mut buf, x.values().as_ptr() as _, b_values as _); - Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) + Bytea::new(pgrx::pg_sys::pq_endtypsend(&mut buf)) } } -#[pgrx::pg_extern(sql = " -CREATE FUNCTION _vectors_svecf32_recv(internal, oid, integer) RETURNS svector -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_svecf32_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> SVecf32Output { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_svecf32_recv(internal: Internal, _oid: Oid, _typmod: i32) -> SVecf32Output { use pgrx::pg_sys::StringInfo; unsafe { let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); diff --git a/src/datatype/binary_vecf16.rs b/src/datatype/binary_vecf16.rs index 3c7aa7efa..45d044d27 100644 --- a/src/datatype/binary_vecf16.rs +++ b/src/datatype/binary_vecf16.rs @@ -1,14 +1,14 @@ +use super::binary::Bytea; use super::memory_vecf16::{Vecf16Input, Vecf16Output}; use base::scalar::F16; use base::vector::Vecf16Borrowed; +use pgrx::datum::Internal; use pgrx::datum::IntoDatum; -use pgrx::pg_sys::{Datum, Oid}; +use pgrx::pg_sys::Oid; use std::ffi::c_char; -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf16_send(vecf16) RETURNS bytea -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf16_send(vector: Vecf16Input<'_>) -> Datum { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_vecf16_send(vector: Vecf16Input<'_>) -> Bytea { use pgrx::pg_sys::StringInfoData; unsafe { let mut buf = StringInfoData::default(); @@ -17,14 +17,12 @@ fn _vectors_vecf16_send(vector: Vecf16Input<'_>) -> Datum { pgrx::pg_sys::pq_begintypsend(&mut buf); pgrx::pg_sys::pq_sendbytes(&mut buf, (&dims) as *const u16 as _, 2); pgrx::pg_sys::pq_sendbytes(&mut buf, vector.slice().as_ptr() as _, b_slice as _); - Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) + Bytea::new(pgrx::pg_sys::pq_endtypsend(&mut buf)) } } -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf16_recv(internal, oid, integer) RETURNS vecf16 -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf16_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> Vecf16Output { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_vecf16_recv(internal: Internal, _oid: Oid, _typmod: i32) -> Vecf16Output { use pgrx::pg_sys::StringInfo; unsafe { let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); diff --git a/src/datatype/binary_vecf32.rs b/src/datatype/binary_vecf32.rs index 4af913f3b..8955d7951 100644 --- a/src/datatype/binary_vecf32.rs +++ b/src/datatype/binary_vecf32.rs @@ -1,14 +1,14 @@ +use super::binary::Bytea; use super::memory_vecf32::{Vecf32Input, Vecf32Output}; use base::scalar::F32; use base::vector::Vecf32Borrowed; +use pgrx::datum::Internal; use pgrx::datum::IntoDatum; -use pgrx::pg_sys::{Datum, Oid}; +use pgrx::pg_sys::Oid; use std::ffi::c_char; -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf32_send(vector) RETURNS bytea -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf32_send(vector: Vecf32Input<'_>) -> Datum { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_vecf32_send(vector: Vecf32Input<'_>) -> Bytea { use pgrx::pg_sys::StringInfoData; unsafe { let mut buf = StringInfoData::default(); @@ -17,14 +17,12 @@ fn _vectors_vecf32_send(vector: Vecf32Input<'_>) -> Datum { pgrx::pg_sys::pq_begintypsend(&mut buf); pgrx::pg_sys::pq_sendbytes(&mut buf, (&dims) as *const u16 as _, 2); pgrx::pg_sys::pq_sendbytes(&mut buf, vector.slice().as_ptr() as _, b_slice as _); - Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) + Bytea::new(pgrx::pg_sys::pq_endtypsend(&mut buf)) } } -#[pgrx::pg_extern(sql = " -CREATE FUNCTION _vectors_vecf32_recv(internal, oid, integer) RETURNS vector -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf32_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> Vecf32Output { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_vecf32_recv(internal: Internal, _oid: Oid, _typmod: i32) -> Vecf32Output { use pgrx::pg_sys::StringInfo; unsafe { let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); diff --git a/src/datatype/binary_veci8.rs b/src/datatype/binary_veci8.rs index 26aae0b32..888fcf261 100644 --- a/src/datatype/binary_veci8.rs +++ b/src/datatype/binary_veci8.rs @@ -1,14 +1,14 @@ +use super::binary::Bytea; use super::memory_veci8::{Veci8Input, Veci8Output}; use base::scalar::{F32, I8}; use base::vector::Veci8Borrowed; +use pgrx::datum::Internal; use pgrx::datum::IntoDatum; -use pgrx::pg_sys::{Datum, Oid}; +use pgrx::pg_sys::Oid; use std::ffi::c_char; -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_veci8_send(veci8) RETURNS bytea -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE C AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_veci8_send(vector: Veci8Input<'_>) -> Datum { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_veci8_send(vector: Veci8Input<'_>) -> Bytea { use pgrx::pg_sys::StringInfoData; unsafe { let mut buf = StringInfoData::default(); @@ -25,14 +25,12 @@ fn _vectors_veci8_send(vector: Veci8Input<'_>) -> Datum { pgrx::pg_sys::pq_sendbytes(&mut buf, (&sum) as *const F32 as _, 4); pgrx::pg_sys::pq_sendbytes(&mut buf, (&l2_norm) as *const F32 as _, 4); pgrx::pg_sys::pq_sendbytes(&mut buf, vector.data().as_ptr() as _, bytes as _); - Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) + Bytea::new(pgrx::pg_sys::pq_endtypsend(&mut buf)) } } -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_veci8_recv(internal, oid, integer) RETURNS veci8 -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE C AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_veci8_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> Veci8Output { +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_veci8_recv(internal: Internal, _oid: Oid, _typmod: i32) -> Veci8Output { use pgrx::pg_sys::StringInfo; unsafe { let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); diff --git a/src/datatype/functions_bvecf32.rs b/src/datatype/functions_bvecf32.rs new file mode 100644 index 000000000..82dbdb4aa --- /dev/null +++ b/src/datatype/functions_bvecf32.rs @@ -0,0 +1,15 @@ +use super::memory_bvecf32::BVecf32Output; +use super::memory_vecf32::Vecf32Input; +use crate::prelude::*; + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_binarize(vector: Vecf32Input<'_>) -> BVecf32Output { + let mut values = BVecf32Owned::new_zeroed(vector.len() as u16); + for (i, &F32(x)) in vector.slice().iter().enumerate() { + if x > 0. { + values.set(i, true); + } + } + + BVecf32Output::new(values.for_borrow()) +} diff --git a/src/datatype/functions.rs b/src/datatype/functions_svecf32.rs similarity index 73% rename from src/datatype/functions.rs rename to src/datatype/functions_svecf32.rs index d2caecd6b..c31a3c5f9 100644 --- a/src/datatype/functions.rs +++ b/src/datatype/functions_svecf32.rs @@ -1,9 +1,5 @@ -use super::memory_bvecf32::BVecf32Output; use super::memory_svecf32::SVecf32Output; -use super::memory_vecf32::Vecf32Input; use crate::prelude::*; -use base::scalar::F32; -use base::vector::SVecf32Borrowed; #[pgrx::pg_extern(immutable, parallel_safe, strict)] fn _vectors_to_svector( @@ -45,15 +41,3 @@ fn _vectors_to_svector( } SVecf32Output::new(SVecf32Borrowed::new(dims.get(), &indexes, &values)) } - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn _vectors_binarize(vector: Vecf32Input<'_>) -> BVecf32Output { - let mut values = BVecf32Owned::new_zeroed(vector.len() as u16); - for (i, &F32(x)) in vector.slice().iter().enumerate() { - if x > 0. { - values.set(i, true); - } - } - - BVecf32Output::new(values.for_borrow()) -} diff --git a/src/datatype/functions_veci8.rs b/src/datatype/functions_veci8.rs index 64c447872..4737fb529 100644 --- a/src/datatype/functions_veci8.rs +++ b/src/datatype/functions_veci8.rs @@ -1,6 +1,5 @@ use crate::datatype::memory_veci8::Veci8Output; use crate::prelude::*; -use base::vector::Veci8Borrowed; #[pgrx::pg_extern(immutable, parallel_safe, strict)] fn _vectors_to_veci8(len: i32, alpha: f32, offset: f32, values: pgrx::Array) -> Veci8Output { diff --git a/src/datatype/mod.rs b/src/datatype/mod.rs index f754465cd..abfa165ff 100644 --- a/src/datatype/mod.rs +++ b/src/datatype/mod.rs @@ -1,10 +1,12 @@ +pub mod binary; pub mod binary_bvecf32; pub mod binary_svecf32; pub mod binary_vecf16; pub mod binary_vecf32; pub mod binary_veci8; pub mod casts; -pub mod functions; +pub mod functions_bvecf32; +pub mod functions_svecf32; pub mod functions_veci8; pub mod memory_bvecf32; pub mod memory_svecf32; diff --git a/src/datatype/subscript_bvecf32.rs b/src/datatype/subscript_bvecf32.rs index 38da79b82..e61528782 100644 --- a/src/datatype/subscript_bvecf32.rs +++ b/src/datatype/subscript_bvecf32.rs @@ -1,12 +1,13 @@ use crate::datatype::memory_bvecf32::{BVecf32Input, BVecf32Output}; use base::vector::{BVecf32Owned, VectorOwned, BVEC_WIDTH}; use pgrx::datum::FromDatum; +use pgrx::datum::Internal; use pgrx::pg_sys::Datum; #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_bvecf32_subscript(internal) RETURNS internal IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_bvecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { +fn _vectors_bvecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { #[pgrx::pg_guard] unsafe extern "C" fn transform( subscript: *mut pgrx::pg_sys::SubscriptingRef, @@ -207,5 +208,5 @@ fn _vectors_bvecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum fetch_leakproof: false, store_leakproof: false, }; - std::ptr::addr_of!(SBSROUTINES).into() + Internal::from(Some(Datum::from(std::ptr::addr_of!(SBSROUTINES)))) } diff --git a/src/datatype/subscript_svecf32.rs b/src/datatype/subscript_svecf32.rs index 51fe10c06..c78cb72cb 100644 --- a/src/datatype/subscript_svecf32.rs +++ b/src/datatype/subscript_svecf32.rs @@ -1,12 +1,13 @@ use crate::datatype::memory_svecf32::{SVecf32Input, SVecf32Output}; use base::vector::SVecf32Borrowed; use pgrx::datum::FromDatum; +use pgrx::datum::Internal; use pgrx::pg_sys::Datum; #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_svecf32_subscript(internal) RETURNS internal IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_svecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { +fn _vectors_svecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { #[pgrx::pg_guard] unsafe extern "C" fn transform( subscript: *mut pgrx::pg_sys::SubscriptingRef, @@ -198,5 +199,5 @@ fn _vectors_svecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum fetch_leakproof: false, store_leakproof: false, }; - std::ptr::addr_of!(SBSROUTINES).into() + Internal::from(Some(Datum::from(std::ptr::addr_of!(SBSROUTINES)))) } diff --git a/src/datatype/subscript_vecf16.rs b/src/datatype/subscript_vecf16.rs index eca3afd6d..3ef53ff67 100644 --- a/src/datatype/subscript_vecf16.rs +++ b/src/datatype/subscript_vecf16.rs @@ -1,12 +1,13 @@ use crate::datatype::memory_vecf16::{Vecf16Input, Vecf16Output}; use base::vector::Vecf16Borrowed; use pgrx::datum::FromDatum; +use pgrx::datum::Internal; use pgrx::pg_sys::Datum; #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_vecf16_subscript(internal) RETURNS internal IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf16_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { +fn _vectors_vecf16_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { #[pgrx::pg_guard] unsafe extern "C" fn transform( subscript: *mut pgrx::pg_sys::SubscriptingRef, @@ -181,5 +182,5 @@ fn _vectors_vecf16_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { fetch_leakproof: false, store_leakproof: false, }; - std::ptr::addr_of!(SBSROUTINES).into() + Internal::from(Some(Datum::from(std::ptr::addr_of!(SBSROUTINES)))) } diff --git a/src/datatype/subscript_vecf32.rs b/src/datatype/subscript_vecf32.rs index 58331ab33..713c727e1 100644 --- a/src/datatype/subscript_vecf32.rs +++ b/src/datatype/subscript_vecf32.rs @@ -1,12 +1,13 @@ use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output}; use base::vector::Vecf32Borrowed; use pgrx::datum::FromDatum; +use pgrx::datum::Internal; use pgrx::pg_sys::Datum; #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_vecf32_subscript(internal) RETURNS internal IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { +fn _vectors_vecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { #[pgrx::pg_guard] unsafe extern "C" fn transform( subscript: *mut pgrx::pg_sys::SubscriptingRef, @@ -181,5 +182,5 @@ fn _vectors_vecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { fetch_leakproof: false, store_leakproof: false, }; - std::ptr::addr_of!(SBSROUTINES).into() + Internal::from(Some(Datum::from(std::ptr::addr_of!(SBSROUTINES)))) } diff --git a/src/datatype/subscript_veci8.rs b/src/datatype/subscript_veci8.rs index 98a954d6a..39c1daba3 100644 --- a/src/datatype/subscript_veci8.rs +++ b/src/datatype/subscript_veci8.rs @@ -1,12 +1,13 @@ use crate::datatype::memory_veci8::{Veci8Input, Veci8Output}; use base::vector::Veci8Borrowed; use pgrx::datum::FromDatum; +use pgrx::datum::Internal; use pgrx::pg_sys::Datum; #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_veci8_subscript(internal) RETURNS internal IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_veci8_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { +fn _vectors_veci8_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { #[pgrx::pg_guard] unsafe extern "C" fn transform( subscript: *mut pgrx::pg_sys::SubscriptingRef, @@ -193,5 +194,5 @@ fn _vectors_veci8_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { fetch_leakproof: false, store_leakproof: false, }; - std::ptr::addr_of!(SBSROUTINES).into() + Internal::from(Some(Datum::from(std::ptr::addr_of!(SBSROUTINES)))) } diff --git a/src/datatype/text_svecf32.rs b/src/datatype/text_svecf32.rs index 69c02ff35..3fca6435e 100644 --- a/src/datatype/text_svecf32.rs +++ b/src/datatype/text_svecf32.rs @@ -2,8 +2,6 @@ use super::memory_svecf32::SVecf32Output; use crate::datatype::memory_svecf32::SVecf32Input; use crate::datatype::typmod::Typmod; use crate::prelude::*; -use base::scalar::F32; -use base::vector::{SVecf32Borrowed, VectorBorrowed}; use pgrx::pg_sys::Oid; use std::ffi::{CStr, CString}; diff --git a/src/datatype/text_vecf16.rs b/src/datatype/text_vecf16.rs index fed1c2d68..5b5ca299d 100644 --- a/src/datatype/text_vecf16.rs +++ b/src/datatype/text_vecf16.rs @@ -2,7 +2,6 @@ use super::memory_vecf16::Vecf16Output; use crate::datatype::memory_vecf16::Vecf16Input; use crate::datatype::typmod::Typmod; use crate::prelude::*; -use base::vector::Vecf16Borrowed; use pgrx::pg_sys::Oid; use std::ffi::{CStr, CString}; diff --git a/src/datatype/text_vecf32.rs b/src/datatype/text_vecf32.rs index bd0c3000d..6927b0c24 100644 --- a/src/datatype/text_vecf32.rs +++ b/src/datatype/text_vecf32.rs @@ -2,7 +2,6 @@ use super::memory_vecf32::Vecf32Output; use crate::datatype::memory_vecf32::Vecf32Input; use crate::datatype::typmod::Typmod; use crate::prelude::*; -use base::vector::Vecf32Borrowed; use pgrx::pg_sys::Oid; use std::ffi::{CStr, CString}; diff --git a/src/datatype/text_veci8.rs b/src/datatype/text_veci8.rs index 094dbf7b3..b858ac8b9 100644 --- a/src/datatype/text_veci8.rs +++ b/src/datatype/text_veci8.rs @@ -1,7 +1,6 @@ use crate::datatype::memory_veci8::{Veci8Input, Veci8Output}; use crate::datatype::typmod::Typmod; use crate::prelude::*; -use base::vector::Veci8Borrowed; use pgrx::pg_sys::Oid; use std::ffi::{CStr, CString}; diff --git a/src/index/am.rs b/src/index/am.rs index 9f1e2983e..792071aca 100644 --- a/src/index/am.rs +++ b/src/index/am.rs @@ -8,6 +8,7 @@ use crate::gucs::planning::ENABLE_INDEX; use crate::index::utils::from_datum; use crate::prelude::*; use crate::utils::cells::PgCell; +use pgrx::datum::Internal; use pgrx::pg_sys::Datum; static RELOPT_KIND: PgCell = unsafe { PgCell::new(0) }; @@ -28,13 +29,12 @@ pub unsafe fn init() { #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_amhandler(internal) RETURNS index_am_handler IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_amhandler(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> pgrx::Internal { +fn _vectors_amhandler(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { type T = pgrx::pg_sys::IndexAmRoutine; unsafe { - use pgrx::FromDatum; let index_am_routine = pgrx::pg_sys::palloc0(std::mem::size_of::()) as *mut T; index_am_routine.write(AM_HANDLER); - pgrx::Internal::from_datum(Datum::from(index_am_routine), false).unwrap() + Internal::from(Some(Datum::from(index_am_routine))) } } diff --git a/src/prelude/error.rs b/src/prelude/error.rs index f96563c28..055349f26 100644 --- a/src/prelude/error.rs +++ b/src/prelude/error.rs @@ -122,9 +122,8 @@ pub fn check_connection(result: Result) -> T { match result { Err(_) => error!( "\ -pgvecto.rs: Indexes can only be built on built-in distance functions. -ADVICE: If you want pgvecto.rs to support more distance functions, \ -visit `https://github.com/tensorchord/pgvecto.rs/issues` and contribute your ideas." +pgvecto.rs: IPC connection is closed unexpectedly. +ADVICE: Visit `https://github.com/tensorchord/pgvecto.rs/issues` for help." ), Ok(x) => x, }