Skip to content

Commit 52eec72

Browse files
committed
Prevent naga crashing on an aliased ray query.
1 parent dcada3d commit 52eec72

File tree

9 files changed

+378
-20
lines changed

9 files changed

+378
-20
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ Naga now infers the correct binding layout when a resource appears only in an as
7777
- Properly apply WGSL's automatic conversions to the arguments to texture sampling functions. By @jimblandy in [#7548](https://github.com/gfx-rs/wgpu/pull/7548).
7878
- Properly evaluate `abs(most negative abstract int)`. By @jimblandy in [#7507](https://github.com/gfx-rs/wgpu/pull/7507).
7979
- Generate vectorized code for `[un]pack4x{I,U}8[Clamp]` on SPIR-V and MSL 2.1+. By @robamler in [#7664](https://github.com/gfx-rs/wgpu/pull/7664).
80+
- Prevent aliased ray queries crashing naga when writing
81+
SPIR-V out. By @Vecvec in [#7759](https://github.com/gfx-rs/wgpu/pull/7759).
8082

8183
#### DX12
8284

naga/src/back/spv/ray.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ impl Writer {
5454
let scalar_type_id = self.get_f32_type_id();
5555
let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
5656

57-
let argument_type_id = self.get_ray_query_pointer_id(ir_module);
57+
let argument_type_id = self.get_ray_query_pointer_id();
5858

5959
let func_ty = self.get_function_type(LookupFunctionType {
6060
parameter_type_ids: vec![argument_type_id],

naga/src/back/spv/writer.rs

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -301,25 +301,9 @@ impl Writer {
301301
self.get_pointer_type_id(base_id, class)
302302
}
303303

304-
pub(super) fn get_ray_query_pointer_id(&mut self, module: &crate::Module) -> Word {
305-
let rq_ty = module
306-
.types
307-
.get(&crate::Type {
308-
name: None,
309-
inner: crate::TypeInner::RayQuery {
310-
vertex_return: false,
311-
},
312-
})
313-
.or_else(|| {
314-
module.types.get(&crate::Type {
315-
name: None,
316-
inner: crate::TypeInner::RayQuery {
317-
vertex_return: true,
318-
},
319-
})
320-
})
321-
.expect("ray_query type should have been populated by the variable passed into this!");
322-
self.get_handle_pointer_type_id(rq_ty, spirv::StorageClass::Function)
304+
pub(super) fn get_ray_query_pointer_id(&mut self) -> Word {
305+
let rq_id = self.get_type_id(LookupType::Local(LocalType::RayQuery));
306+
self.get_pointer_type_id(rq_id, spirv::StorageClass::Function)
323307
}
324308

325309
/// Return a SPIR-V type for a pointer to `resolution`.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
god_mode = true
2+
targets = "SPIRV | METAL | HLSL"
3+
4+
[msl]
5+
fake_missing_bindings = true
6+
lang_version = [2, 4]
7+
spirv_cross_compatibility = false
8+
zero_initialize_workgroup_memory = false
9+
10+
[hlsl]
11+
shader_model = "V6_5"
12+
fake_missing_bindings = true
13+
zero_initialize_workgroup_memory = true
14+
15+
[spv]
16+
version = [1, 4]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
alias rq = ray_query;
2+
3+
@group(0) @binding(0)
4+
var acc_struct: acceleration_structure;
5+
6+
@compute @workgroup_size(1)
7+
fn main_candidate() {
8+
let pos = vec3<f32>(0.0);
9+
let dir = vec3<f32>(0.0, 1.0, 0.0);
10+
11+
var rq: rq;
12+
rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, pos, dir));
13+
let intersection = rayQueryGetCandidateIntersection(&rq);
14+
if (intersection.kind == RAY_QUERY_INTERSECTION_AABB) {
15+
rayQueryGenerateIntersection(&rq, 10.0);
16+
} else if (intersection.kind == RAY_QUERY_INTERSECTION_TRIANGLE) {
17+
rayQueryConfirmIntersection(&rq);
18+
} else {
19+
rayQueryTerminate(&rq);
20+
}
21+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
struct RayDesc_ {
2+
uint flags;
3+
uint cull_mask;
4+
float tmin;
5+
float tmax;
6+
float3 origin;
7+
int _pad5_0;
8+
float3 dir;
9+
int _end_pad_0;
10+
};
11+
12+
struct RayIntersection {
13+
uint kind;
14+
float t;
15+
uint instance_custom_data;
16+
uint instance_index;
17+
uint sbt_record_offset;
18+
uint geometry_index;
19+
uint primitive_index;
20+
float2 barycentrics;
21+
bool front_face;
22+
int _pad9_0;
23+
int _pad9_1;
24+
row_major float4x3 object_to_world;
25+
int _pad10_0;
26+
row_major float4x3 world_to_object;
27+
int _end_pad_0;
28+
};
29+
30+
RayDesc RayDescFromRayDesc_(RayDesc_ arg0) {
31+
RayDesc ret = (RayDesc)0;
32+
ret.Origin = arg0.origin;
33+
ret.TMin = arg0.tmin;
34+
ret.Direction = arg0.dir;
35+
ret.TMax = arg0.tmax;
36+
return ret;
37+
}
38+
39+
RaytracingAccelerationStructure acc_struct : register(t0);
40+
41+
RayDesc_ ConstructRayDesc_(uint arg0, uint arg1, float arg2, float arg3, float3 arg4, float3 arg5) {
42+
RayDesc_ ret = (RayDesc_)0;
43+
ret.flags = arg0;
44+
ret.cull_mask = arg1;
45+
ret.tmin = arg2;
46+
ret.tmax = arg3;
47+
ret.origin = arg4;
48+
ret.dir = arg5;
49+
return ret;
50+
}
51+
52+
RayIntersection GetCandidateIntersection(RayQuery<RAY_FLAG_NONE> rq) {
53+
RayIntersection ret = (RayIntersection)0;
54+
CANDIDATE_TYPE kind = rq.CandidateType();
55+
if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {
56+
ret.kind = 1;
57+
ret.t = rq.CandidateTriangleRayT();
58+
ret.barycentrics = rq.CandidateTriangleBarycentrics();
59+
ret.front_face = rq.CandidateTriangleFrontFace();
60+
} else {
61+
ret.kind = 3;
62+
}
63+
ret.instance_custom_data = rq.CandidateInstanceID();
64+
ret.instance_index = rq.CandidateInstanceIndex();
65+
ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex();
66+
ret.geometry_index = rq.CandidateGeometryIndex();
67+
ret.primitive_index = rq.CandidatePrimitiveIndex();
68+
ret.object_to_world = rq.CandidateObjectToWorld4x3();
69+
ret.world_to_object = rq.CandidateWorldToObject4x3();
70+
return ret;
71+
}
72+
73+
[numthreads(1, 1, 1)]
74+
void main_candidate()
75+
{
76+
RayQuery<RAY_FLAG_NONE> rq_1;
77+
78+
float3 pos = (0.0).xxx;
79+
float3 dir = float3(0.0, 1.0, 0.0);
80+
rq_1.TraceRayInline(acc_struct, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir)));
81+
RayIntersection intersection = GetCandidateIntersection(rq_1);
82+
if ((intersection.kind == 3u)) {
83+
rq_1.CommitProceduralPrimitiveHit(10.0);
84+
return;
85+
} else {
86+
if ((intersection.kind == 1u)) {
87+
rq_1.CommitNonOpaqueTriangleHit();
88+
return;
89+
} else {
90+
rq_1.Abort();
91+
return;
92+
}
93+
}
94+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
(
2+
vertex:[
3+
],
4+
fragment:[
5+
],
6+
compute:[
7+
(
8+
entry_point:"main_candidate",
9+
target_profile:"cs_6_5",
10+
),
11+
],
12+
)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// language: metal2.4
2+
#include <metal_stdlib>
3+
#include <simd/simd.h>
4+
5+
using metal::uint;
6+
struct _RayQuery {
7+
metal::raytracing::intersector<metal::raytracing::instancing, metal::raytracing::triangle_data, metal::raytracing::world_space_data> intersector;
8+
metal::raytracing::intersector<metal::raytracing::instancing, metal::raytracing::triangle_data, metal::raytracing::world_space_data>::result_type intersection;
9+
bool ready = false;
10+
};
11+
constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) {
12+
return ty==metal::raytracing::intersection_type::triangle ? 1 :
13+
ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0;
14+
}
15+
16+
struct RayDesc {
17+
uint flags;
18+
uint cull_mask;
19+
float tmin;
20+
float tmax;
21+
metal::float3 origin;
22+
metal::float3 dir;
23+
};
24+
struct RayIntersection {
25+
uint kind;
26+
float t;
27+
uint instance_custom_data;
28+
uint instance_index;
29+
uint sbt_record_offset;
30+
uint geometry_index;
31+
uint primitive_index;
32+
metal::float2 barycentrics;
33+
bool front_face;
34+
char _pad9[11];
35+
metal::float4x3 object_to_world;
36+
metal::float4x3 world_to_object;
37+
};
38+
39+
kernel void main_candidate(
40+
metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]]
41+
) {
42+
_RayQuery rq_1 = {};
43+
metal::float3 pos = metal::float3(0.0);
44+
metal::float3 dir = metal::float3(0.0, 1.0, 0.0);
45+
RayDesc _e12 = RayDesc {4u, 255u, 0.1, 100.0, pos, dir};
46+
rq_1.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
47+
rq_1.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
48+
rq_1.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
49+
rq_1.intersector.accept_any_intersection((_e12.flags & 4) != 0);
50+
rq_1.intersection = rq_1.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq_1.ready = true;
51+
RayIntersection intersection = RayIntersection {_map_intersection_type(rq_1.intersection.type), rq_1.intersection.distance, rq_1.intersection.user_instance_id, rq_1.intersection.instance_id, {}, rq_1.intersection.geometry_id, rq_1.intersection.primitive_id, rq_1.intersection.triangle_barycentric_coord, rq_1.intersection.triangle_front_facing, {}, rq_1.intersection.object_to_world_transform, rq_1.intersection.world_to_object_transform};
52+
if (intersection.kind == 3u) {
53+
return;
54+
} else {
55+
if (intersection.kind == 1u) {
56+
return;
57+
} else {
58+
rq_1.ready = false;
59+
return;
60+
}
61+
}
62+
}

0 commit comments

Comments
 (0)