
Commit 4722a99

feat(wgpu): add to_dtype kernel (#906)

* feat(wgpu): add to_dtype kernel
* fix: add WebGPUNativeType
* style: clippy fix

Co-authored-by: Corey Lowman <[email protected]>

1 parent e04dd4f, commit 4722a99

File tree

6 files changed (+228, -20 lines)

dfdx-core/src/tensor/webgpu/device.rs

Lines changed: 30 additions & 0 deletions
@@ -247,6 +247,36 @@ impl Webgpu {
     pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
         self.cs_cache.read().get(&name).cloned()
     }
+
+    /// Submit a command buffer to the GPU.
+    ///
+    /// Note: Does not block until completion. If you need this, use
+    /// `self.dev.poll(Maintain::WaitForSubmissionIndex(idx))` with the
+    /// returned [`wgpu::SubmissionIndex`].
+    pub(crate) fn submit_commands<F>(
+        &self,
+        label: Option<&str>,
+        command_builder: F,
+    ) -> wgpu::SubmissionIndex
+    where
+        F: FnOnce(&mut wgpu::CommandEncoder),
+    {
+        let mut encoder = self
+            .dev
+            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label });
+        if let Some(label) = label {
+            encoder.push_debug_group(label);
+        }
+        command_builder(&mut encoder);
+        if label.is_some() {
+            encoder.pop_debug_group();
+        }
+
+        let cmd = [encoder.finish()];
+        self.queue.submit(cmd)
+    }
 
     // #[allow(unused)]
     // pub(crate) unsafe fn get_workspace<E>(&self, len: usize) -> Result<MutexGuard<Buffer>, Error> {
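Since `submit_commands` only queues work, a crate-internal call site that needs the results immediately has to poll the device itself, as the doc comment suggests. A minimal sketch (the label and encoder body are illustrative):

    // Record a pass, then block until this specific submission completes.
    let idx = device.submit_commands(Some("my-kernel"), |encoder| {
        // record compute passes / buffer copies here
    });
    device.dev.poll(wgpu::Maintain::WaitForSubmissionIndex(idx));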

dfdx-core/src/tensor/webgpu/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -1,8 +1,10 @@
 mod allocate;
 mod device;
+mod types;
 
 pub use device::Buffer;
 pub use device::Webgpu;
+pub use types::*;
 
 #[cfg(test)]
 mod tests {

dfdx-core/src/tensor/webgpu/types.rs

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+use crate::shapes::Unit;
+
+/// A primitive data type natively supported by WebGPU.
+///
+/// See: https://www.w3.org/TR/WGSL/#types
+///
+/// todo: support packed types
+pub trait WebgpuNativeType: Unit {
+    /// Name of the data type in WGSL.
+    const NAME: &'static str;
+}
+
+macro_rules! webgpu_type {
+    ($RustTy:ty) => {
+        impl WebgpuNativeType for $RustTy {
+            const NAME: &'static str = stringify!($RustTy);
+        }
+    };
+    ($RustTy:ty, $WgpuTy:expr) => {
+        impl WebgpuNativeType for $RustTy {
+            const NAME: &'static str = $WgpuTy;
+        }
+    };
+}
+
+/*
+see:
+- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F16
+- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F64
+- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_I16
+*/
+#[cfg(feature = "f16")]
+webgpu_type!(half::f16, "f16");
+webgpu_type!(f32);
+#[cfg(feature = "f64")]
+webgpu_type!(f64);
+
+#[cfg(feature = "i16")]
+webgpu_type!(i16);
+webgpu_type!(i32);
+
+webgpu_type!(u32);
+webgpu_type!(bool);
+
+pub(crate) trait HasGlslType {
+    const TYPE: &'static str;
+}
+
+impl HasGlslType for f32 {
+    const TYPE: &'static str = "float";
+}
+
+impl HasGlslType for f64 {
+    const TYPE: &'static str = "double";
+}
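The macro's first arm reuses the Rust type name as the WGSL name via `stringify!`, while the second arm takes an explicit WGSL name for types like `half::f16`. A quick illustration of what the generated impls expose:

    // Illustrative: WGSL names produced by the webgpu_type! invocations above.
    assert_eq!(<f32 as WebgpuNativeType>::NAME, "f32");
    assert_eq!(<u32 as WebgpuNativeType>::NAME, "u32");
    #[cfg(feature = "f16")]
    assert_eq!(<half::f16 as WebgpuNativeType>::NAME, "f16");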
dfdx-core/src/tensor_ops/to_dtype/to_dtype.wgsl

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+alias T = __SRC__;
+alias U = __DST__;
+
+@group(0) @binding(0)
+var<storage, read> in: array<T>;
+
+@group(0) @binding(1)
+var<storage, read_write> out: array<U>;
+
+@compute @workgroup_size(1, 1, 1)
+fn main(
+    @builtin(global_invocation_id) global_id: vec3<u32>
+) {
+    let i = global_id.x;
+    out[i] = U(in[i]);
+}
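The `__SRC__` and `__DST__` placeholders are not valid WGSL on their own; the Rust kernel below substitutes concrete type names before compiling the module. For example, converting f32 to i32 (illustrative values) yields a shader beginning with `alias T = f32;` and `alias U = i32;`:

    // Sketch of the substitution performed by the kernel below.
    let shader = include_str!("./to_dtype.wgsl")
        .replace("__SRC__", "f32")
        .replace("__DST__", "i32");
    assert!(shader.starts_with("alias T = f32;"));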
dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs

Lines changed: 96 additions & 3 deletions

@@ -1,9 +1,102 @@
-use crate::prelude::{Unit, Webgpu};
+use crate::{
+    prelude::Storage,
+    tensor::webgpu::{Webgpu, WebgpuNativeType},
+    tensor_ops::utilities::webgpu_kernels::webgpu_params,
+};
+use num_traits::AsPrimitive;
+use wgpu;
 
-impl<E1: Unit, E2: Unit> super::ToDtypeKernel<E1, E2> for Webgpu {
+/// kernel template
+const KERNEL: &'static str = include_str!("./to_dtype.wgsl");
+
+const LAYOUT_DESC: wgpu::BindGroupLayoutDescriptor = wgpu::BindGroupLayoutDescriptor {
+    label: Some("to-dtype"),
+    entries: &[
+        wgpu::BindGroupLayoutEntry {
+            binding: 0,
+            visibility: wgpu::ShaderStages::COMPUTE,
+            ty: wgpu::BindingType::Buffer {
+                ty: wgpu::BufferBindingType::Storage { read_only: true },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+        wgpu::BindGroupLayoutEntry {
+            binding: 1,
+            visibility: wgpu::ShaderStages::COMPUTE,
+            ty: wgpu::BindingType::Buffer {
+                ty: wgpu::BufferBindingType::Storage { read_only: false },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+    ],
+};
+
+impl<E1: WebgpuNativeType + AsPrimitive<E2>, E2: WebgpuNativeType> super::ToDtypeKernel<E1, E2>
+    for Webgpu
+{
     fn forward<S: crate::prelude::Shape>(
         inp: crate::prelude::Tensor<S, E1, Self>,
     ) -> Result<crate::prelude::Tensor<S, E2, Self>, crate::prelude::Error> {
-        todo!()
+        let module_name = std::format!("convert_{}_to_{}", E1::NAME, E2::NAME);
+        let label = Some(module_name.as_str());
+        let device = inp.device;
+
+        let layout = device.dev.create_bind_group_layout(&LAYOUT_DESC);
+        let shader_source: String = KERNEL
+            .replace("__SRC__", E1::NAME)
+            .replace("__DST__", E2::NAME);
+
+        // TODO: support WGSL shaders in device shader cache
+        let source = wgpu::ShaderSource::Wgsl(shader_source.into());
+        let shader_module = device
+            .dev
+            .create_shader_module(wgpu::ShaderModuleDescriptor { label, source });
+        let pipeline_layout = device
+            .dev
+            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+                label,
+                bind_group_layouts: &[&layout],
+                // todo: push constants are useful and we should use them if the adapter supports them
+                push_constant_ranges: &[],
+            });
+
+        let pipeline = device
+            .dev
+            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+                label,
+                layout: Some(&pipeline_layout),
+                module: &shader_module,
+                entry_point: "main",
+            });
+
+        let numel = inp.shape.num_elements();
+        let shape = inp.shape;
+        let strides = shape.strides();
+        let output = unsafe { device.alloc_empty::<E2>(numel) }?;
+
+        let params: wgpu::BindGroup = webgpu_params!(device, pipeline; inp.data, output);
+
+        let _idx = device.submit_commands(label, |encoder| {
+            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
+                label,
+                ..Default::default()
+            });
+            // TODO: should this be set as the pass is created, or just before submission?
+            pass.set_pipeline(&pipeline);
+            pass.set_bind_group(0, &params, &[]);
+            // one workgroup per element, since the shader's workgroup size is (1, 1, 1)
+            pass.dispatch_workgroups(numel as u32, 1, 1);
+        });
+
+        // note: no need to sync here, the buffer can remain on the gpu until to_array or
+        // to_vec gets called, and those functions sync the device before mapping the buffer
+        Ok(device.build_tensor(shape, strides, output))
     }
 }
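With this impl in place, dtype conversion on the WebGPU device flows through the generic `to_dtype` tensor op. A hedged sketch of a call site, assuming the webgpu feature is enabled and an adapter is available:

    // Illustrative usage; exercises the convert_f32_to_i32 pipeline above.
    let dev = Webgpu::default();
    let a = dev.tensor([1.0f32, 2.0, 3.0]);
    let b = a.to_dtype::<i32>();
    assert_eq!(b.to_vec(), [1, 2, 3]);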

dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs

Lines changed: 28 additions & 17 deletions
@@ -6,6 +6,33 @@ use crate::{
 use core::any::TypeId;
 use std::{borrow::Cow, marker::PhantomData, sync::Arc, vec::Vec};
 
+use wgpu::{
+    BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
+};
+
+/// Creates a [`BindGroup`] for a pipeline from a set of [`wgpu::BindingResource`]s.
+macro_rules! webgpu_params {
+    ($self:expr, $pipeline:expr; $($x:expr),+ $(,)? ) => {
+        {
+            let bindings = [$($x.as_entire_binding()),+];
+            let entries: Vec<_> = bindings
+                .into_iter()
+                .enumerate()
+                .map(|(i, binding)| wgpu::BindGroupEntry {
+                    binding: i as u32,
+                    resource: binding,
+                })
+                .collect();
+            $self.dev.create_bind_group(&::wgpu::BindGroupDescriptor {
+                label: None,
+                layout: &($pipeline).get_bind_group_layout(0),
+                entries: &entries
+            })
+        }
+    }
+}
+pub(crate) use webgpu_params;
+
 pub(crate) trait UnaryOpWebgpuKernel<E> {
     const DF_USES_FX: bool;
     const HAS_CONST_DF: bool;

@@ -49,6 +76,7 @@ macro_rules! webgpu_unary {
         }
     };
 }
+pub(crate) use webgpu_unary;
 
 /// Zero-sized marker type for forward pass TypeId
 #[derive(Debug, Default)]

@@ -62,23 +90,6 @@ pub(crate) struct Backward<E: Dtype, K> {
     _phantom: PhantomData<(E, K)>,
 }
 
-pub(crate) trait HasGlslType {
-    const TYPE: &'static str;
-}
-
-impl HasGlslType for f32 {
-    const TYPE: &'static str = "float";
-}
-
-impl HasGlslType for f64 {
-    const TYPE: &'static str = "double";
-}
-
-pub(crate) use webgpu_unary;
-use wgpu::{
-    BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
-};
-
 impl<E: Dtype + HasGlslType, K: UnaryOpWebgpuKernel<E> + 'static> UnaryKernel<K, E> for Webgpu {
     const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX;
     const BACKWARD_WITHOUT_DATA: bool = K::HAS_CONST_DF;
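For reference, the to_dtype kernel above calls this macro as `webgpu_params!(device, pipeline; inp.data, output)`, which binds the buffers to bindings 0 and 1 of the pipeline's first bind group; binding indices follow argument order. A hypothetical standalone use:

    // Hypothetical: input_buf and output_buf are buffers exposing as_entire_binding().
    let params: wgpu::BindGroup = webgpu_params!(device, pipeline; input_buf, output_buf);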