Skip to content

Commit b0a91df

Browse files
committed
Use function pointers for AVX-512 path
perf stat -d -d -d --delay 250 ./bench/run_func data/twitter.json.xz dumps 5000 7349323334 instructions:u 1503705855 branches:u 7064976507 instructions:u 1395888450 branches:u
1 parent 56d6457 commit b0a91df

File tree

7 files changed

+56
-33
lines changed

7 files changed

+56
-33
lines changed

src/serialize/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ mod obtype;
66
mod per_type;
77
mod serializer;
88
mod state;
9-
mod writer;
9+
pub mod writer;
1010

1111
pub use serializer::serialize;

src/serialize/writer/json.rs

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,22 @@ where
590590
}
591591
}
592592

593+
#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))]
594+
type StrFormatter = unsafe fn(*mut u8, *const u8, usize) -> usize;
595+
596+
#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))]
597+
static mut STR_FORMATTER_FN: StrFormatter =
598+
crate::serialize::writer::str::format_escaped_str_impl_sse2_128;
599+
600+
pub fn set_str_formatter_fn() {
601+
unsafe {
602+
#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))]
603+
if std::is_x86_feature_detected!("avx512vl") {
604+
STR_FORMATTER_FN = crate::serialize::writer::str::format_escaped_str_impl_512vl;
605+
}
606+
}
607+
}
608+
593609
#[cfg(all(
594610
feature = "unstable-simd",
595611
target_arch = "x86_64",
@@ -622,21 +638,13 @@ where
622638
unsafe {
623639
reserve_str!(writer, value);
624640

625-
if std::is_x86_feature_detected!("avx512vl") {
626-
let written = crate::serialize::writer::str::format_escaped_str_impl_512vl(
627-
writer.as_mut_buffer_ptr(),
628-
value.as_bytes().as_ptr(),
629-
value.len(),
630-
);
631-
writer.set_written(written);
632-
} else {
633-
let written = crate::serialize::writer::str::format_escaped_str_impl_sse2_128(
634-
writer.as_mut_buffer_ptr(),
635-
value.as_bytes().as_ptr(),
636-
value.len(),
637-
);
638-
writer.set_written(written);
639-
}
641+
let written = STR_FORMATTER_FN(
642+
writer.as_mut_buffer_ptr(),
643+
value.as_bytes().as_ptr(),
644+
value.len(),
645+
);
646+
647+
writer.set_written(written);
640648
}
641649
}
642650

@@ -654,6 +662,7 @@ where
654662
value.as_bytes().as_ptr(),
655663
value.len(),
656664
);
665+
657666
writer.set_written(written);
658667
}
659668
}

src/serialize/writer/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ mod json;
66
mod str;
77

88
pub use byteswriter::{BytesWriter, WriteExt};
9-
pub use json::{to_writer, to_writer_pretty};
9+
pub use json::{set_str_formatter_fn, to_writer, to_writer_pretty};

src/str/avx512.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,18 @@ pub fn unicode_from_str(buf: &str) -> *mut pyo3_ffi::PyObject {
9393
if unlikely!(buf.is_empty()) {
9494
return use_immortal!(crate::typeref::EMPTY_UNICODE);
9595
}
96+
STR_CREATE_FN(buf)
97+
}
98+
}
99+
100+
pub type StrDeserializer = unsafe fn(&str) -> *mut pyo3_ffi::PyObject;
101+
102+
static mut STR_CREATE_FN: StrDeserializer = super::scalar::str_impl_kind_scalar;
103+
104+
pub fn set_str_create_fn() {
105+
unsafe {
96106
if std::is_x86_feature_detected!("avx512vl") {
97-
create_str_impl_avx512vl(buf)
98-
} else {
99-
super::scalar::unicode_from_str(buf)
107+
STR_CREATE_FN = create_str_impl_avx512vl;
100108
}
101109
}
102110
}

src/str/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ mod pyunicode_new;
77
mod scalar;
88

99
#[cfg(not(feature = "avx512"))]
10-
pub use scalar::unicode_from_str;
10+
pub use scalar::{set_str_create_fn, unicode_from_str};
1111

1212
#[cfg(feature = "avx512")]
13-
pub use avx512::unicode_from_str;
13+
pub use avx512::{set_str_create_fn, unicode_from_str};
1414

1515
pub use ffi::*;

src/str/scalar.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
use crate::str::pyunicode_new::{
44
pyunicode_ascii, pyunicode_fourbyte, pyunicode_onebyte, pyunicode_twobyte,
55
};
6-
use crate::typeref::EMPTY_UNICODE;
76

8-
#[inline(always)]
9-
pub fn str_impl_kind_scalar(buf: &str, num_chars: usize) -> *mut pyo3_ffi::PyObject {
7+
#[inline(never)]
8+
pub fn str_impl_kind_scalar(buf: &str) -> *mut pyo3_ffi::PyObject {
9+
let num_chars = bytecount::num_chars(buf.as_bytes());
10+
if buf.len() == num_chars {
11+
return pyunicode_ascii(buf.as_ptr(), num_chars);
12+
}
1013
unsafe {
1114
let len = buf.len();
1215
assume!(len > 0);
@@ -33,15 +36,14 @@ pub fn str_impl_kind_scalar(buf: &str, num_chars: usize) -> *mut pyo3_ffi::PyObj
3336
}
3437
}
3538

36-
#[inline(never)]
39+
#[cfg(not(feature = "avx512"))]
40+
#[inline(always)]
3741
pub fn unicode_from_str(buf: &str) -> *mut pyo3_ffi::PyObject {
3842
if unlikely!(buf.is_empty()) {
39-
return use_immortal!(EMPTY_UNICODE);
40-
}
41-
let num_chars = bytecount::num_chars(buf.as_bytes());
42-
if buf.len() == num_chars {
43-
pyunicode_ascii(buf.as_ptr(), num_chars)
44-
} else {
45-
str_impl_kind_scalar(buf, num_chars)
43+
return use_immortal!(crate::typeref::EMPTY_UNICODE);
4644
}
45+
str_impl_kind_scalar(buf)
4746
}
47+
48+
#[cfg(not(feature = "avx512"))]
49+
pub fn set_str_create_fn() {}

src/typeref.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ fn _init_typerefs_impl() -> bool {
140140
assert!(crate::deserialize::KEY_MAP
141141
.set(crate::deserialize::KeyMap::default())
142142
.is_ok());
143+
144+
crate::serialize::writer::set_str_formatter_fn();
145+
crate::str::set_str_create_fn();
146+
143147
FRAGMENT_TYPE = orjson_fragmenttype_new();
144148
PyDateTime_IMPORT();
145149
NONE = Py_None();

0 commit comments

Comments
 (0)