Skip to content

Commit

Permalink
[naga wgsl-in] Handle automatic type conversions for switch selector …
Browse files Browse the repository at this point in the history
…and case expressions

This allows abstract-typed expressions to be used for some or all of
the switch selector and case selectors. If these are all not
convertible to the same concrete scalar integer type we return an
error. If all the selector expressions are abstract then they are
concretized to i32.

The note previously provided by the relevant error message, suggesting
adding or removing the `u` suffix from case values, has been
removed. While useful for simple literal values, it was comically
incorrect for more complex case expressions. The error message should
still be useful enough to allow the user to easily identify the
problem.
  • Loading branch information
jamienicol committed Feb 28, 2025
1 parent 3297e9f commit b941b89
Show file tree
Hide file tree
Showing 12 changed files with 700 additions and 326 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).
- Forward '--keep-coordinate-space' flag to GLSL backend in naga-cli. By @cloone8 in [#7206](https://github.com/gfx-rs/wgpu/pull/7206).
- Allow template lists to have a trailing comma. By @KentSlaney in [#7142](https://github.com/gfx-rs/wgpu/pull/7142).
- Allow WGSL const declarations to have abstract types. By @jamienicol in [#7055](https://github.com/gfx-rs/wgpu/pull/7055).
- Allow abstract types to be used for WGSL switch statement selector and case selector expressions. By @jamienicol in [#7250](https://github.com/gfx-rs/wgpu/pull/7250).
#### General
Expand Down
47 changes: 29 additions & 18 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,13 @@ pub(crate) enum Error<'a> {
/// the same identifier as `ident`, above.
path: Box<[(Span, Span)]>,
},
InvalidSwitchValue {
uint: bool,
InvalidSwitchSelector {
span: Span,
},
InvalidSwitchCase {
span: Span,
},
SwitchCaseTypeMismatch {
span: Span,
},
CalledEntryPoint(Span),
Expand Down Expand Up @@ -763,26 +768,32 @@ impl<'a> Error<'a> {
.collect(),
notes: vec![],
},
Error::InvalidSwitchValue { uint, span } => ParseError {
message: "invalid switch value".to_string(),
Error::InvalidSwitchSelector { span } => ParseError {
message: "invalid switch selector".to_string(),
labels: vec![(
span,
if uint {
"expected unsigned integer"
} else {
"expected signed integer"
}
"switch selector must be a scalar integer"
.into(),
)],
notes: vec![if uint {
format!("suffix the integer with a `u`: `{}u`", &source[span])
} else {
let span = span.to_range().unwrap();
format!(
"remove the `u` suffix: `{}`",
&source[span.start..span.end - 1]
)
}],
notes: vec![],
},
Error::InvalidSwitchCase { span } => ParseError {
message: "invalid switch case value".to_string(),
labels: vec![(
span,
"switch case selector must be a scalar integer const expression"
.into(),
)],
notes: vec![],
},
Error::SwitchCaseTypeMismatch { span } => ParseError {
message: "invalid switch case selector value".to_string(),
labels: vec![(
span,
"switch case selector must have the same type as the switch selector expression"
.into(),
)],
notes: vec![],
},
Error::CalledEntryPoint(span) => ParseError {
message: "entry point cannot be called".to_string(),
Expand Down
64 changes: 55 additions & 9 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1621,30 +1621,76 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
emitter.start(&ctx.function.expressions);

let mut ectx = ctx.as_expression(block, &mut emitter);
let selector = self.expression(selector, &mut ectx)?;

let uint =
resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint);
// Determine the scalar type of the selector and case expressions, find the
// consensus type for automatic conversion, then convert them.
let (mut exprs, spans) = std::iter::once(selector)
.chain(cases.iter().filter_map(|case| match case.value {
ast::SwitchValue::Expr(expr) => Some(expr),
ast::SwitchValue::Default => None,
}))
.enumerate()
.map(|(i, expr)| {
let span = ectx.ast_expressions.get_span(expr);
let expr = self.expression_for_abstract(expr, &mut ectx)?;
let ty = resolve_inner!(ectx, expr);
match *ty {
crate::TypeInner::Scalar(
crate::Scalar::I32
| crate::Scalar::U32
| crate::Scalar::ABSTRACT_INT,
) => Ok((expr, span)),
_ => match i {
0 => Err(Error::InvalidSwitchSelector { span }),
_ => Err(Error::InvalidSwitchCase { span }),
},
}
})
.collect::<Result<(Vec<_>, Vec<_>), _>>()?;

let mut consensus =
ectx.automatic_conversion_consensus(&exprs)
.map_err(|span_idx| Error::SwitchCaseTypeMismatch {
span: spans[span_idx],
})?;
// Concretize to I32 if the selector and all cases were abstract
if consensus == crate::Scalar::ABSTRACT_INT {
consensus = crate::Scalar::I32;
}
for expr in &mut exprs {
ectx.convert_to_leaf_scalar(expr, consensus)?;
}

block.extend(emitter.finish(&ctx.function.expressions));

let mut exprs = exprs.into_iter();
let selector = exprs
.next()
.expect("First element should be selector expression");

let cases = cases
.iter()
.map(|case| {
Ok(crate::SwitchCase {
value: match case.value {
ast::SwitchValue::Expr(expr) => {
let span = ctx.ast_expressions.get_span(expr);
let expr =
self.expression(expr, &mut ctx.as_global().as_const())?;
match ctx.module.to_ctx().eval_expr_to_literal(expr) {
Some(crate::Literal::I32(value)) if !uint => {
let expr = exprs.next().expect(
"Should yield expression for each SwitchValue::Expr case",
);
match ctx
.module
.to_ctx()
.eval_expr_to_literal_from(expr, &ctx.function.expressions)
{
Some(crate::Literal::I32(value)) => {
crate::SwitchValue::I32(value)
}
Some(crate::Literal::U32(value)) if uint => {
Some(crate::Literal::U32(value)) => {
crate::SwitchValue::U32(value)
}
_ => {
return Err(Error::InvalidSwitchValue { uint, span });
return Err(Error::InvalidSwitchCase { span });
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions naga/src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ fn parse_parentheses_switch() {
parse_str(
"
fn main() {
var pos: f32;
switch pos > 1.0 {
default: { pos = 3.0; }
var pos: i32;
switch pos + 1 {
default: { pos = 3; }
}
}
",
Expand Down
2 changes: 1 addition & 1 deletion naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ impl GlobalCtx<'_> {
self.eval_expr_to_literal_from(handle, self.global_expressions)
}

fn eval_expr_to_literal_from(
pub(super) fn eval_expr_to_literal_from(
&self,
handle: crate::Handle<crate::Expression>,
arena: &crate::Arena<crate::Expression>,
Expand Down
35 changes: 35 additions & 0 deletions naga/tests/in/wgsl/control-flow.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,41 @@ fn switch_case_break() {
return;
}

fn switch_selector_type_conversion() {
switch (0u) {
case 0: {
}
default: {
}
}

switch (0) {
case 0u: {
}
default: {
}
}
}

const ONE = 1;
fn switch_const_expr_case_selectors() {
const TWO = 2;
switch (0) {
case i32(): {
}
case ONE: {
}
case TWO: {
}
case 1 + 2: {
}
case vec4(4).x: {
}
default: {
}
}
}

fn loop_switch_continue(x: i32) {
loop {
switch x {
Expand Down
42 changes: 42 additions & 0 deletions naga/tests/out/glsl/control-flow.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,48 @@ void switch_case_break() {
return;
}

void switch_selector_type_conversion() {
switch(0u) {
case 0u: {
break;
}
default: {
break;
}
}
switch(0u) {
case 0u: {
return;
}
default: {
return;
}
}
}

void switch_const_expr_case_selectors() {
switch(0) {
case 0: {
return;
}
case 1: {
return;
}
case 2: {
return;
}
case 3: {
return;
}
case 4: {
return;
}
default: {
return;
}
}
}

void loop_switch_continue(int x) {
while(true) {
switch(x) {
Expand Down
44 changes: 44 additions & 0 deletions naga/tests/out/hlsl/control-flow.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,50 @@ void switch_case_break()
return;
}

void switch_selector_type_conversion()
{
switch(0u) {
case 0u: {
break;
}
default: {
break;
}
}
switch(0u) {
case 0u: {
return;
}
default: {
return;
}
}
}

void switch_const_expr_case_selectors()
{
switch(int(0)) {
case 0: {
return;
}
case 1: {
return;
}
case 2: {
return;
}
case 3: {
return;
}
case 4: {
return;
}
default: {
return;
}
}
}

void loop_switch_continue(int x)
{
uint2 loop_bound = uint2(0u, 0u);
Expand Down
44 changes: 44 additions & 0 deletions naga/tests/out/msl/control-flow.msl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,50 @@ void switch_case_break(
return;
}

void switch_selector_type_conversion(
) {
switch(0u) {
case 0u: {
break;
}
default: {
break;
}
}
switch(0u) {
case 0u: {
return;
}
default: {
return;
}
}
}

void switch_const_expr_case_selectors(
) {
switch(0) {
case 0: {
return;
}
case 1: {
return;
}
case 2: {
return;
}
case 3: {
return;
}
case 4: {
return;
}
default: {
return;
}
}
}

void loop_switch_continue(
int x
) {
Expand Down
Loading

0 comments on commit b941b89

Please sign in to comment.