Skip to content

Commit

Permalink
count_ones: fix bit_reverse, must be u32-only in vulkan
Browse files Browse the repository at this point in the history
  • Loading branch information
Firestar99 committed Feb 10, 2025
1 parent 6f72d86 commit 10f0283
Showing 1 changed file with 73 additions and 5 deletions.
78 changes: 73 additions & 5 deletions crates/rustc_codegen_spirv/src/builder/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,7 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
sym::cttz_nonzero => self.count_leading_trailing_zeros(args[0].immediate(), true, true),

sym::ctpop => self.count_ones(args[0].immediate()),
sym::bitreverse => self
.emit()
.bit_reverse(args[0].immediate().ty, None, args[0].immediate().def(self))
.unwrap()
.with_type(args[0].immediate().ty),
sym::bitreverse => self.bit_reverse(args[0].immediate()),
sym::bswap => {
// https://github.com/KhronosGroup/SPIRV-LLVM/pull/221/files
// TODO: Definitely add tests to make sure this impl is right.
Expand Down Expand Up @@ -418,6 +414,78 @@ impl Builder<'_, '_> {
_ => self.fatal("count_ones on a non-integer type"),
}
}
pub fn bit_reverse(&self, arg: SpirvValue) -> SpirvValue {
let ty = arg.ty;
match self.cx.lookup_type(ty) {
SpirvType::Integer(bits, signed) => {
let u32 = SpirvType::Integer(32, false).def(self.span(), self);
let uint = SpirvType::Integer(bits, false).def(self.span(), self);

match (bits, signed) {
(8 | 16, signed) => {
let arg = arg.def(self);
let arg = if signed {
self.emit().bitcast(uint, None, arg).unwrap()
} else {
arg
};
let arg = self.emit().u_convert(u32, None, arg).unwrap();

let reverse = self.emit().bit_reverse(u32, None, arg).unwrap();
let shift = self.constant_u32(self.span(), 32 - bits).def(self);
let reverse = self.emit().shift_right_logical(u32, None, reverse, shift).unwrap();
let reverse = self.emit().u_convert(uint, None, reverse).unwrap();
if signed {
self.emit().bitcast(ty, None, reverse).unwrap()
} else {
reverse
}
}
(32, false) => self.emit().bit_reverse(u32, None, arg.def(self)).unwrap(),
(32, true) => {
let arg = self.emit().bitcast(u32, None, arg.def(self)).unwrap();
let reverse = self.emit().bit_reverse(u32, None, arg).unwrap();
self.emit().bitcast(ty, None, reverse).unwrap()
},
(64, signed) => {
let u32_32 = self.constant_u32(self.span(), 32).def(self);
let arg = arg.def(self);
let lower = self.emit().s_convert(u32, None, arg).unwrap();
let higher = self
.emit()
.shift_left_logical(ty, None, arg, u32_32)
.unwrap();
let higher = self.emit().s_convert(u32, None, higher).unwrap();

// note that higher and lower have swapped
let higher_bits = self.emit().bit_reverse(u32, None, lower).unwrap();
let lower_bits = self.emit().bit_reverse(u32, None, higher).unwrap();

let higher_bits = self.emit().u_convert(uint, None, higher_bits).unwrap();
let shift = self.constant_u32(self.span(), 32).def(self);
let higher_bits = self.emit().shift_right_logical(uint, None, higher_bits, shift).unwrap();
let lower_bits = self.emit().u_convert(uint, None, lower_bits).unwrap();

let result = self.emit().bitwise_or(ty, None, lower_bits, higher_bits).unwrap();
if signed {
self.emit().bitcast(ty, None, result).unwrap()
} else {
result
}
}
_ => {
let undef = self.undef(ty).def(self);
self.zombie(undef, &format!(
"counting leading / trailing zeros on unsupported {ty:?} bit integer type"
));
undef
}
}
.with_type(ty)
}
_ => self.fatal("count_ones on a non-integer type"),
}
}

pub fn count_leading_trailing_zeros(
&self,
Expand Down

0 comments on commit 10f0283

Please sign in to comment.