Skip to content

Commit b705494

Browse files
committed
[naga wgsl-in] Handle automatic type conversions for switch selector and case expressions
This allows abstract-typed expressions to be used for some or all of the switch selector and case selectors. If these are all not convertible to the same concrete scalar integer type we return an error. If all the selector expressions are abstract then they are concretized to i32. The note previously provided by the relevant error message, suggesting adding or removing the `u` suffix from case values, has been removed. While useful for simple literal values, it was comically incorrect for more complex case expressions. The error message should still be useful enough to allow the user to easily identify the problem.
1 parent 7df6e47 commit b705494

File tree

12 files changed

+713
-339
lines changed

12 files changed

+713
-339
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ By @Vecvec in [#6905](https://github.com/gfx-rs/wgpu/pull/6905), [#7086](https:/
199199
- Forward '--keep-coordinate-space' flag to GLSL backend in naga-cli. By @cloone8 in [#7206](https://github.com/gfx-rs/wgpu/pull/7206).
200200
- Allow template lists to have a trailing comma. By @KentSlaney in [#7142](https://github.com/gfx-rs/wgpu/pull/7142).
201201
- Allow WGSL const declarations to have abstract types. By @jamienicol in [#7055](https://github.com/gfx-rs/wgpu/pull/7055) and [#7222](https://github.com/gfx-rs/wgpu/pull/7222).
202+
- Allow abstract types to be used for WGSL switch statement selector and case selector expressions. By @jamienicol in [#7250](https://github.com/gfx-rs/wgpu/pull/7250).
202203
203204
#### General
204205

naga/src/front/wgsl/error.rs

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,13 @@ pub(crate) enum Error<'a> {
256256
/// the same identifier as `ident`, above.
257257
path: Box<[(Span, Span)]>,
258258
},
259-
InvalidSwitchValue {
260-
uint: bool,
259+
InvalidSwitchSelector {
260+
span: Span,
261+
},
262+
InvalidSwitchCase {
263+
span: Span,
264+
},
265+
SwitchCaseTypeMismatch {
261266
span: Span,
262267
},
263268
CalledEntryPoint(Span),
@@ -772,26 +777,32 @@ impl<'a> Error<'a> {
772777
.collect(),
773778
notes: vec![],
774779
},
775-
Error::InvalidSwitchValue { uint, span } => ParseError {
776-
message: "invalid switch value".to_string(),
780+
Error::InvalidSwitchSelector { span } => ParseError {
781+
message: "invalid `switch` selector".to_string(),
777782
labels: vec![(
778783
span,
779-
if uint {
780-
"expected unsigned integer"
781-
} else {
782-
"expected signed integer"
783-
}
784+
"`switch` selector must be a scalar integer"
784785
.into(),
785786
)],
786-
notes: vec![if uint {
787-
format!("suffix the integer with a `u`: `{}u`", &source[span])
788-
} else {
789-
let span = span.to_range().unwrap();
790-
format!(
791-
"remove the `u` suffix: `{}`",
792-
&source[span.start..span.end - 1]
793-
)
794-
}],
787+
notes: vec![],
788+
},
789+
Error::InvalidSwitchCase { span } => ParseError {
790+
message: "invalid `switch` case selector value".to_string(),
791+
labels: vec![(
792+
span,
793+
"`switch` case selector must be a scalar integer const expression"
794+
.into(),
795+
)],
796+
notes: vec![],
797+
},
798+
Error::SwitchCaseTypeMismatch { span } => ParseError {
799+
message: "invalid `switch` case selector value".to_string(),
800+
labels: vec![(
801+
span,
802+
"`switch` case selector must have the same type as the `switch` selector expression"
803+
.into(),
804+
)],
805+
notes: vec![],
795806
},
796807
Error::CalledEntryPoint(span) => ParseError {
797808
message: "entry point cannot be called".to_string(),

naga/src/front/wgsl/lower/mod.rs

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,30 +1630,76 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16301630
emitter.start(&ctx.function.expressions);
16311631

16321632
let mut ectx = ctx.as_expression(block, &mut emitter);
1633-
let selector = self.expression(selector, &mut ectx)?;
16341633

1635-
let uint =
1636-
resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint);
1634+
// Determine the scalar type of the selector and case expressions, find the
1635+
// consensus type for automatic conversion, then convert them.
1636+
let (mut exprs, spans) = core::iter::once(selector)
1637+
.chain(cases.iter().filter_map(|case| match case.value {
1638+
ast::SwitchValue::Expr(expr) => Some(expr),
1639+
ast::SwitchValue::Default => None,
1640+
}))
1641+
.enumerate()
1642+
.map(|(i, expr)| {
1643+
let span = ectx.ast_expressions.get_span(expr);
1644+
let expr = self.expression_for_abstract(expr, &mut ectx)?;
1645+
let ty = resolve_inner!(ectx, expr);
1646+
match *ty {
1647+
crate::TypeInner::Scalar(
1648+
crate::Scalar::I32
1649+
| crate::Scalar::U32
1650+
| crate::Scalar::ABSTRACT_INT,
1651+
) => Ok((expr, span)),
1652+
_ => match i {
1653+
0 => Err(Error::InvalidSwitchSelector { span }),
1654+
_ => Err(Error::InvalidSwitchCase { span }),
1655+
},
1656+
}
1657+
})
1658+
.collect::<Result<(Vec<_>, Vec<_>), _>>()?;
1659+
1660+
let mut consensus =
1661+
ectx.automatic_conversion_consensus(&exprs)
1662+
.map_err(|span_idx| Error::SwitchCaseTypeMismatch {
1663+
span: spans[span_idx],
1664+
})?;
1665+
// Concretize to I32 if the selector and all cases were abstract
1666+
if consensus == crate::Scalar::ABSTRACT_INT {
1667+
consensus = crate::Scalar::I32;
1668+
}
1669+
for expr in &mut exprs {
1670+
ectx.convert_to_leaf_scalar(expr, consensus)?;
1671+
}
1672+
16371673
block.extend(emitter.finish(&ctx.function.expressions));
16381674

1675+
let mut exprs = exprs.into_iter();
1676+
let selector = exprs
1677+
.next()
1678+
.expect("First element should be selector expression");
1679+
16391680
let cases = cases
16401681
.iter()
16411682
.map(|case| {
16421683
Ok(crate::SwitchCase {
16431684
value: match case.value {
16441685
ast::SwitchValue::Expr(expr) => {
16451686
let span = ctx.ast_expressions.get_span(expr);
1646-
let expr =
1647-
self.expression(expr, &mut ctx.as_global().as_const())?;
1648-
match ctx.module.to_ctx().eval_expr_to_literal(expr) {
1649-
Some(crate::Literal::I32(value)) if !uint => {
1687+
let expr = exprs.next().expect(
1688+
"Should yield expression for each SwitchValue::Expr case",
1689+
);
1690+
match ctx
1691+
.module
1692+
.to_ctx()
1693+
.eval_expr_to_literal_from(expr, &ctx.function.expressions)
1694+
{
1695+
Some(crate::Literal::I32(value)) => {
16501696
crate::SwitchValue::I32(value)
16511697
}
1652-
Some(crate::Literal::U32(value)) if uint => {
1698+
Some(crate::Literal::U32(value)) => {
16531699
crate::SwitchValue::U32(value)
16541700
}
16551701
_ => {
1656-
return Err(Error::InvalidSwitchValue { uint, span });
1702+
return Err(Error::InvalidSwitchCase { span });
16571703
}
16581704
}
16591705
}

naga/src/front/wgsl/tests.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,9 @@ fn parse_parentheses_switch() {
322322
parse_str(
323323
"
324324
fn main() {
325-
var pos: f32;
326-
switch pos > 1.0 {
327-
default: { pos = 3.0; }
325+
var pos: i32;
326+
switch pos + 1 {
327+
default: { pos = 3; }
328328
}
329329
}
330330
",

naga/src/proc/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ impl GlobalCtx<'_> {
455455
self.eval_expr_to_literal_from(handle, self.global_expressions)
456456
}
457457

458-
fn eval_expr_to_literal_from(
458+
pub(super) fn eval_expr_to_literal_from(
459459
&self,
460460
handle: crate::Handle<crate::Expression>,
461461
arena: &crate::Arena<crate::Expression>,

naga/tests/in/wgsl/control-flow.wgsl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,41 @@ fn switch_case_break() {
7979
return;
8080
}
8181

82+
fn switch_selector_type_conversion() {
83+
switch (0u) {
84+
case 0: {
85+
}
86+
default: {
87+
}
88+
}
89+
90+
switch (0) {
91+
case 0u: {
92+
}
93+
default: {
94+
}
95+
}
96+
}
97+
98+
const ONE = 1;
99+
fn switch_const_expr_case_selectors() {
100+
const TWO = 2;
101+
switch (0) {
102+
case i32(): {
103+
}
104+
case ONE: {
105+
}
106+
case TWO: {
107+
}
108+
case 1 + 2: {
109+
}
110+
case vec4(4).x: {
111+
}
112+
default: {
113+
}
114+
}
115+
}
116+
82117
fn loop_switch_continue(x: i32) {
83118
loop {
84119
switch x {

naga/tests/out/glsl/control-flow.main.Compute.glsl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,48 @@ void switch_case_break() {
2424
return;
2525
}
2626

27+
void switch_selector_type_conversion() {
28+
switch(0u) {
29+
case 0u: {
30+
break;
31+
}
32+
default: {
33+
break;
34+
}
35+
}
36+
switch(0u) {
37+
case 0u: {
38+
return;
39+
}
40+
default: {
41+
return;
42+
}
43+
}
44+
}
45+
46+
void switch_const_expr_case_selectors() {
47+
switch(0) {
48+
case 0: {
49+
return;
50+
}
51+
case 1: {
52+
return;
53+
}
54+
case 2: {
55+
return;
56+
}
57+
case 3: {
58+
return;
59+
}
60+
case 4: {
61+
return;
62+
}
63+
default: {
64+
return;
65+
}
66+
}
67+
}
68+
2769
void loop_switch_continue(int x) {
2870
while(true) {
2971
switch(x) {

naga/tests/out/hlsl/control-flow.hlsl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,50 @@ void switch_case_break()
1818
return;
1919
}
2020

21+
void switch_selector_type_conversion()
22+
{
23+
switch(0u) {
24+
case 0u: {
25+
break;
26+
}
27+
default: {
28+
break;
29+
}
30+
}
31+
switch(0u) {
32+
case 0u: {
33+
return;
34+
}
35+
default: {
36+
return;
37+
}
38+
}
39+
}
40+
41+
void switch_const_expr_case_selectors()
42+
{
43+
switch(int(0)) {
44+
case 0: {
45+
return;
46+
}
47+
case 1: {
48+
return;
49+
}
50+
case 2: {
51+
return;
52+
}
53+
case 3: {
54+
return;
55+
}
56+
case 4: {
57+
return;
58+
}
59+
default: {
60+
return;
61+
}
62+
}
63+
}
64+
2165
void loop_switch_continue(int x)
2266
{
2367
uint2 loop_bound = uint2(0u, 0u);

naga/tests/out/msl/control-flow.msl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,50 @@ void switch_case_break(
2828
return;
2929
}
3030

31+
void switch_selector_type_conversion(
32+
) {
33+
switch(0u) {
34+
case 0u: {
35+
break;
36+
}
37+
default: {
38+
break;
39+
}
40+
}
41+
switch(0u) {
42+
case 0u: {
43+
return;
44+
}
45+
default: {
46+
return;
47+
}
48+
}
49+
}
50+
51+
void switch_const_expr_case_selectors(
52+
) {
53+
switch(0) {
54+
case 0: {
55+
return;
56+
}
57+
case 1: {
58+
return;
59+
}
60+
case 2: {
61+
return;
62+
}
63+
case 3: {
64+
return;
65+
}
66+
case 4: {
67+
return;
68+
}
69+
default: {
70+
return;
71+
}
72+
}
73+
}
74+
3175
void loop_switch_continue(
3276
int x
3377
) {

0 commit comments

Comments
 (0)