Skip to content

Commit b941b89

Browse files
committed
[naga wgsl-in] Handle automatic type conversions for switch selector and case expressions
This allows abstract-typed expressions to be used for some or all of the switch selector and case selectors. If these are all not convertible to the same concrete scalar integer type we return an error. If all the selector expressions are abstract then they are concretized to i32. The note previously provided by the relevant error message, suggesting adding or removing the `u` suffix from case values, has been removed. While useful for simple literal values, it was comically incorrect for more complex case expressions. The error message should still be useful enough to allow the user to easily identify the problem.
1 parent 3297e9f commit b941b89

File tree

12 files changed

+700
-326
lines changed

12 files changed

+700
-326
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).
195195
- Forward '--keep-coordinate-space' flag to GLSL backend in naga-cli. By @cloone8 in [#7206](https://github.com/gfx-rs/wgpu/pull/7206).
196196
- Allow template lists to have a trailing comma. By @KentSlaney in [#7142](https://github.com/gfx-rs/wgpu/pull/7142).
197197
- Allow WGSL const declarations to have abstract types. By @jamienicol in [#7055](https://github.com/gfx-rs/wgpu/pull/7055).
198+
- Allow abstract types to be used for WGSL switch statement selector and case selector expressions. By @jamienicol in [#7250](https://github.com/gfx-rs/wgpu/pull/7250).
198199
199200
#### General
200201

naga/src/front/wgsl/error.rs

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,13 @@ pub(crate) enum Error<'a> {
247247
/// the same identifier as `ident`, above.
248248
path: Box<[(Span, Span)]>,
249249
},
250-
InvalidSwitchValue {
251-
uint: bool,
250+
InvalidSwitchSelector {
251+
span: Span,
252+
},
253+
InvalidSwitchCase {
254+
span: Span,
255+
},
256+
SwitchCaseTypeMismatch {
252257
span: Span,
253258
},
254259
CalledEntryPoint(Span),
@@ -763,26 +768,32 @@ impl<'a> Error<'a> {
763768
.collect(),
764769
notes: vec![],
765770
},
766-
Error::InvalidSwitchValue { uint, span } => ParseError {
767-
message: "invalid switch value".to_string(),
771+
Error::InvalidSwitchSelector { span } => ParseError {
772+
message: "invalid switch selector".to_string(),
768773
labels: vec![(
769774
span,
770-
if uint {
771-
"expected unsigned integer"
772-
} else {
773-
"expected signed integer"
774-
}
775+
"switch selector must be a scalar integer"
775776
.into(),
776777
)],
777-
notes: vec![if uint {
778-
format!("suffix the integer with a `u`: `{}u`", &source[span])
779-
} else {
780-
let span = span.to_range().unwrap();
781-
format!(
782-
"remove the `u` suffix: `{}`",
783-
&source[span.start..span.end - 1]
784-
)
785-
}],
778+
notes: vec![],
779+
},
780+
Error::InvalidSwitchCase { span } => ParseError {
781+
message: "invalid switch case value".to_string(),
782+
labels: vec![(
783+
span,
784+
"switch case selector must be a scalar integer const expression"
785+
.into(),
786+
)],
787+
notes: vec![],
788+
},
789+
Error::SwitchCaseTypeMismatch { span } => ParseError {
790+
message: "invalid switch case selector value".to_string(),
791+
labels: vec![(
792+
span,
793+
"switch case selector must have the same type as the switch selector expression"
794+
.into(),
795+
)],
796+
notes: vec![],
786797
},
787798
Error::CalledEntryPoint(span) => ParseError {
788799
message: "entry point cannot be called".to_string(),

naga/src/front/wgsl/lower/mod.rs

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,30 +1621,76 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16211621
emitter.start(&ctx.function.expressions);
16221622

16231623
let mut ectx = ctx.as_expression(block, &mut emitter);
1624-
let selector = self.expression(selector, &mut ectx)?;
16251624

1626-
let uint =
1627-
resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint);
1625+
// Determine the scalar type of the selector and case expressions, find the
1626+
// consensus type for automatic conversion, then convert them.
1627+
let (mut exprs, spans) = std::iter::once(selector)
1628+
.chain(cases.iter().filter_map(|case| match case.value {
1629+
ast::SwitchValue::Expr(expr) => Some(expr),
1630+
ast::SwitchValue::Default => None,
1631+
}))
1632+
.enumerate()
1633+
.map(|(i, expr)| {
1634+
let span = ectx.ast_expressions.get_span(expr);
1635+
let expr = self.expression_for_abstract(expr, &mut ectx)?;
1636+
let ty = resolve_inner!(ectx, expr);
1637+
match *ty {
1638+
crate::TypeInner::Scalar(
1639+
crate::Scalar::I32
1640+
| crate::Scalar::U32
1641+
| crate::Scalar::ABSTRACT_INT,
1642+
) => Ok((expr, span)),
1643+
_ => match i {
1644+
0 => Err(Error::InvalidSwitchSelector { span }),
1645+
_ => Err(Error::InvalidSwitchCase { span }),
1646+
},
1647+
}
1648+
})
1649+
.collect::<Result<(Vec<_>, Vec<_>), _>>()?;
1650+
1651+
let mut consensus =
1652+
ectx.automatic_conversion_consensus(&exprs)
1653+
.map_err(|span_idx| Error::SwitchCaseTypeMismatch {
1654+
span: spans[span_idx],
1655+
})?;
1656+
// Concretize to I32 if the selector and all cases were abstract
1657+
if consensus == crate::Scalar::ABSTRACT_INT {
1658+
consensus = crate::Scalar::I32;
1659+
}
1660+
for expr in &mut exprs {
1661+
ectx.convert_to_leaf_scalar(expr, consensus)?;
1662+
}
1663+
16281664
block.extend(emitter.finish(&ctx.function.expressions));
16291665

1666+
let mut exprs = exprs.into_iter();
1667+
let selector = exprs
1668+
.next()
1669+
.expect("First element should be selector expression");
1670+
16301671
let cases = cases
16311672
.iter()
16321673
.map(|case| {
16331674
Ok(crate::SwitchCase {
16341675
value: match case.value {
16351676
ast::SwitchValue::Expr(expr) => {
16361677
let span = ctx.ast_expressions.get_span(expr);
1637-
let expr =
1638-
self.expression(expr, &mut ctx.as_global().as_const())?;
1639-
match ctx.module.to_ctx().eval_expr_to_literal(expr) {
1640-
Some(crate::Literal::I32(value)) if !uint => {
1678+
let expr = exprs.next().expect(
1679+
"Should yield expression for each SwitchValue::Expr case",
1680+
);
1681+
match ctx
1682+
.module
1683+
.to_ctx()
1684+
.eval_expr_to_literal_from(expr, &ctx.function.expressions)
1685+
{
1686+
Some(crate::Literal::I32(value)) => {
16411687
crate::SwitchValue::I32(value)
16421688
}
1643-
Some(crate::Literal::U32(value)) if uint => {
1689+
Some(crate::Literal::U32(value)) => {
16441690
crate::SwitchValue::U32(value)
16451691
}
16461692
_ => {
1647-
return Err(Error::InvalidSwitchValue { uint, span });
1693+
return Err(Error::InvalidSwitchCase { span });
16481694
}
16491695
}
16501696
}

naga/src/front/wgsl/tests.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,9 @@ fn parse_parentheses_switch() {
320320
parse_str(
321321
"
322322
fn main() {
323-
var pos: f32;
324-
switch pos > 1.0 {
325-
default: { pos = 3.0; }
323+
var pos: i32;
324+
switch pos + 1 {
325+
default: { pos = 3; }
326326
}
327327
}
328328
",

naga/src/proc/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ impl GlobalCtx<'_> {
455455
self.eval_expr_to_literal_from(handle, self.global_expressions)
456456
}
457457

458-
fn eval_expr_to_literal_from(
458+
pub(super) fn eval_expr_to_literal_from(
459459
&self,
460460
handle: crate::Handle<crate::Expression>,
461461
arena: &crate::Arena<crate::Expression>,

naga/tests/in/wgsl/control-flow.wgsl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,41 @@ fn switch_case_break() {
7979
return;
8080
}
8181

82+
fn switch_selector_type_conversion() {
83+
switch (0u) {
84+
case 0: {
85+
}
86+
default: {
87+
}
88+
}
89+
90+
switch (0) {
91+
case 0u: {
92+
}
93+
default: {
94+
}
95+
}
96+
}
97+
98+
const ONE = 1;
99+
fn switch_const_expr_case_selectors() {
100+
const TWO = 2;
101+
switch (0) {
102+
case i32(): {
103+
}
104+
case ONE: {
105+
}
106+
case TWO: {
107+
}
108+
case 1 + 2: {
109+
}
110+
case vec4(4).x: {
111+
}
112+
default: {
113+
}
114+
}
115+
}
116+
82117
fn loop_switch_continue(x: i32) {
83118
loop {
84119
switch x {

naga/tests/out/glsl/control-flow.main.Compute.glsl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,48 @@ void switch_case_break() {
2424
return;
2525
}
2626

27+
void switch_selector_type_conversion() {
28+
switch(0u) {
29+
case 0u: {
30+
break;
31+
}
32+
default: {
33+
break;
34+
}
35+
}
36+
switch(0u) {
37+
case 0u: {
38+
return;
39+
}
40+
default: {
41+
return;
42+
}
43+
}
44+
}
45+
46+
void switch_const_expr_case_selectors() {
47+
switch(0) {
48+
case 0: {
49+
return;
50+
}
51+
case 1: {
52+
return;
53+
}
54+
case 2: {
55+
return;
56+
}
57+
case 3: {
58+
return;
59+
}
60+
case 4: {
61+
return;
62+
}
63+
default: {
64+
return;
65+
}
66+
}
67+
}
68+
2769
void loop_switch_continue(int x) {
2870
while(true) {
2971
switch(x) {

naga/tests/out/hlsl/control-flow.hlsl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,50 @@ void switch_case_break()
1818
return;
1919
}
2020

21+
void switch_selector_type_conversion()
22+
{
23+
switch(0u) {
24+
case 0u: {
25+
break;
26+
}
27+
default: {
28+
break;
29+
}
30+
}
31+
switch(0u) {
32+
case 0u: {
33+
return;
34+
}
35+
default: {
36+
return;
37+
}
38+
}
39+
}
40+
41+
void switch_const_expr_case_selectors()
42+
{
43+
switch(int(0)) {
44+
case 0: {
45+
return;
46+
}
47+
case 1: {
48+
return;
49+
}
50+
case 2: {
51+
return;
52+
}
53+
case 3: {
54+
return;
55+
}
56+
case 4: {
57+
return;
58+
}
59+
default: {
60+
return;
61+
}
62+
}
63+
}
64+
2165
void loop_switch_continue(int x)
2266
{
2367
uint2 loop_bound = uint2(0u, 0u);

naga/tests/out/msl/control-flow.msl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,50 @@ void switch_case_break(
2828
return;
2929
}
3030

31+
void switch_selector_type_conversion(
32+
) {
33+
switch(0u) {
34+
case 0u: {
35+
break;
36+
}
37+
default: {
38+
break;
39+
}
40+
}
41+
switch(0u) {
42+
case 0u: {
43+
return;
44+
}
45+
default: {
46+
return;
47+
}
48+
}
49+
}
50+
51+
void switch_const_expr_case_selectors(
52+
) {
53+
switch(0) {
54+
case 0: {
55+
return;
56+
}
57+
case 1: {
58+
return;
59+
}
60+
case 2: {
61+
return;
62+
}
63+
case 3: {
64+
return;
65+
}
66+
case 4: {
67+
return;
68+
}
69+
default: {
70+
return;
71+
}
72+
}
73+
}
74+
3175
void loop_switch_continue(
3276
int x
3377
) {

0 commit comments

Comments
 (0)