Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga wgsl-in] Handle automatic type conversions for switch selector and case expressions #7250

Merged
merged 1 commit into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ By @Vecvec in [#6905](https://github.com/gfx-rs/wgpu/pull/6905), [#7086](https:/
- Allow template lists to have a trailing comma. By @KentSlaney in [#7142](https://github.com/gfx-rs/wgpu/pull/7142).
- Allow WGSL const declarations to have abstract types. By @jamienicol in [#7055](https://github.com/gfx-rs/wgpu/pull/7055) and [#7222](https://github.com/gfx-rs/wgpu/pull/7222).
- Allows override-sized arrays to resolve to the same size without causing the type arena to panic. By @KentSlaney in [#7082](https://github.com/gfx-rs/wgpu/pull/7082).
- Allow abstract types to be used for WGSL switch statement selector and case selector expressions. By @jamienicol in [#7250](https://github.com/gfx-rs/wgpu/pull/7250).

#### General

Expand Down
47 changes: 29 additions & 18 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,13 @@ pub(crate) enum Error<'a> {
/// the same identifier as `ident`, above.
path: Box<[(Span, Span)]>,
},
InvalidSwitchValue {
uint: bool,
InvalidSwitchSelector {
span: Span,
},
InvalidSwitchCase {
span: Span,
},
SwitchCaseTypeMismatch {
span: Span,
},
CalledEntryPoint(Span),
Expand Down Expand Up @@ -772,26 +777,32 @@ impl<'a> Error<'a> {
.collect(),
notes: vec![],
},
Error::InvalidSwitchValue { uint, span } => ParseError {
message: "invalid switch value".to_string(),
Error::InvalidSwitchSelector { span } => ParseError {
message: "invalid `switch` selector".to_string(),
labels: vec![(
span,
if uint {
"expected unsigned integer"
} else {
"expected signed integer"
}
"`switch` selector must be a scalar integer"
.into(),
)],
notes: vec![if uint {
format!("suffix the integer with a `u`: `{}u`", &source[span])
} else {
let span = span.to_range().unwrap();
format!(
"remove the `u` suffix: `{}`",
&source[span.start..span.end - 1]
)
}],
notes: vec![],
},
Error::InvalidSwitchCase { span } => ParseError {
message: "invalid `switch` case selector value".to_string(),
labels: vec![(
span,
"`switch` case selector must be a scalar integer const expression"
.into(),
)],
notes: vec![],
},
Error::SwitchCaseTypeMismatch { span } => ParseError {
message: "invalid `switch` case selector value".to_string(),
labels: vec![(
span,
"`switch` case selector must have the same type as the `switch` selector expression"
.into(),
)],
notes: vec![],
},
Error::CalledEntryPoint(span) => ParseError {
message: "entry point cannot be called".to_string(),
Expand Down
64 changes: 55 additions & 9 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1630,30 +1630,76 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
emitter.start(&ctx.function.expressions);

let mut ectx = ctx.as_expression(block, &mut emitter);
let selector = self.expression(selector, &mut ectx)?;

let uint =
resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint);
// Determine the scalar type of the selector and case expressions, find the
// consensus type for automatic conversion, then convert them.
let (mut exprs, spans) = core::iter::once(selector)
.chain(cases.iter().filter_map(|case| match case.value {
ast::SwitchValue::Expr(expr) => Some(expr),
ast::SwitchValue::Default => None,
}))
.enumerate()
.map(|(i, expr)| {
let span = ectx.ast_expressions.get_span(expr);
let expr = self.expression_for_abstract(expr, &mut ectx)?;
let ty = resolve_inner!(ectx, expr);
match *ty {
crate::TypeInner::Scalar(
crate::Scalar::I32
| crate::Scalar::U32
| crate::Scalar::ABSTRACT_INT,
) => Ok((expr, span)),
_ => match i {
0 => Err(Error::InvalidSwitchSelector { span }),
_ => Err(Error::InvalidSwitchCase { span }),
},
}
})
.collect::<Result<(Vec<_>, Vec<_>), _>>()?;

let mut consensus =
ectx.automatic_conversion_consensus(&exprs)
.map_err(|span_idx| Error::SwitchCaseTypeMismatch {
span: spans[span_idx],
})?;
// Concretize to I32 if the selector and all cases were abstract
if consensus == crate::Scalar::ABSTRACT_INT {
consensus = crate::Scalar::I32;
}
for expr in &mut exprs {
ectx.convert_to_leaf_scalar(expr, consensus)?;
}

block.extend(emitter.finish(&ctx.function.expressions));

let mut exprs = exprs.into_iter();
let selector = exprs
.next()
.expect("First element should be selector expression");

let cases = cases
.iter()
.map(|case| {
Ok(crate::SwitchCase {
value: match case.value {
ast::SwitchValue::Expr(expr) => {
let span = ctx.ast_expressions.get_span(expr);
let expr =
self.expression(expr, &mut ctx.as_global().as_const())?;
match ctx.module.to_ctx().eval_expr_to_literal(expr) {
Some(crate::Literal::I32(value)) if !uint => {
let expr = exprs.next().expect(
"Should yield expression for each SwitchValue::Expr case",
);
match ctx
.module
.to_ctx()
.eval_expr_to_literal_from(expr, &ctx.function.expressions)
{
Some(crate::Literal::I32(value)) => {
crate::SwitchValue::I32(value)
}
Some(crate::Literal::U32(value)) if uint => {
Some(crate::Literal::U32(value)) => {
crate::SwitchValue::U32(value)
}
_ => {
return Err(Error::InvalidSwitchValue { uint, span });
return Err(Error::InvalidSwitchCase { span });
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions naga/src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,9 @@ fn parse_parentheses_switch() {
parse_str(
"
fn main() {
var pos: f32;
switch pos > 1.0 {
default: { pos = 3.0; }
var pos: i32;
switch pos + 1 {
default: { pos = 3; }
}
}
",
Expand Down
2 changes: 1 addition & 1 deletion naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ impl GlobalCtx<'_> {
self.eval_expr_to_literal_from(handle, self.global_expressions)
}

fn eval_expr_to_literal_from(
pub(super) fn eval_expr_to_literal_from(
&self,
handle: crate::Handle<crate::Expression>,
arena: &crate::Arena<crate::Expression>,
Expand Down
35 changes: 35 additions & 0 deletions naga/tests/in/wgsl/control-flow.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,41 @@ fn switch_case_break() {
return;
}

fn switch_selector_type_conversion() {
switch (0u) {
case 0: {
}
default: {
}
}

switch (0) {
case 0u: {
}
default: {
}
}
}

const ONE = 1;
fn switch_const_expr_case_selectors() {
const TWO = 2;
switch (0) {
case i32(): {
}
case ONE: {
}
case TWO: {
}
case 1 + 2: {
}
case vec4(4).x: {
}
default: {
}
}
}

fn loop_switch_continue(x: i32) {
loop {
switch x {
Expand Down
42 changes: 42 additions & 0 deletions naga/tests/out/glsl/control-flow.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,48 @@ void switch_case_break() {
return;
}

void switch_selector_type_conversion() {
switch(0u) {
case 0u: {
break;
}
default: {
break;
}
}
switch(0u) {
case 0u: {
return;
}
default: {
return;
}
}
}

void switch_const_expr_case_selectors() {
switch(0) {
case 0: {
return;
}
case 1: {
return;
}
case 2: {
return;
}
case 3: {
return;
}
case 4: {
return;
}
default: {
return;
}
}
}

void loop_switch_continue(int x) {
while(true) {
switch(x) {
Expand Down
44 changes: 44 additions & 0 deletions naga/tests/out/hlsl/control-flow.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,50 @@ void switch_case_break()
return;
}

void switch_selector_type_conversion()
{
switch(0u) {
case 0u: {
break;
}
default: {
break;
}
}
switch(0u) {
case 0u: {
return;
}
default: {
return;
}
}
}

void switch_const_expr_case_selectors()
{
switch(int(0)) {
case 0: {
return;
}
case 1: {
return;
}
case 2: {
return;
}
case 3: {
return;
}
case 4: {
return;
}
default: {
return;
}
}
}

void loop_switch_continue(int x)
{
uint2 loop_bound = uint2(0u, 0u);
Expand Down
44 changes: 44 additions & 0 deletions naga/tests/out/msl/control-flow.msl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,50 @@ void switch_case_break(
return;
}

void switch_selector_type_conversion(
) {
switch(0u) {
case 0u: {
break;
}
default: {
break;
}
}
switch(0u) {
case 0u: {
return;
}
default: {
return;
}
}
}

void switch_const_expr_case_selectors(
) {
switch(0) {
case 0: {
return;
}
case 1: {
return;
}
case 2: {
return;
}
case 3: {
return;
}
case 4: {
return;
}
default: {
return;
}
}
}

void loop_switch_continue(
int x
) {
Expand Down
Loading