Skip to content

Commit 5439341

Browse files
committed
no-std
1 parent 21c9f62 commit 5439341

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

dfdx-core/src/tensor/webgpu/device.rs

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,13 @@ use crate::{
1313
};
1414

1515
#[cfg(feature = "no-std")]
16-
use spin::Mutex;
16+
use spin::{Mutex, RwLock};
1717

1818
use core::any::TypeId;
1919
#[cfg(not(feature = "no-std"))]
20-
use std::sync::Mutex;
20+
use std::sync::{Mutex, RwLock};
2121

22-
use std::{
23-
collections::HashMap,
24-
marker::PhantomData,
25-
sync::{Arc, RwLock},
26-
vec::Vec,
27-
};
22+
use std::{collections::HashMap, marker::PhantomData, sync::Arc, vec::Vec};
2823

2924
use super::allocate::round_to_buffer_alignment;
3025

@@ -209,22 +204,37 @@ impl Webgpu {
209204
Ok(data)
210205
}
211206

207+
#[cfg(not(feature = "no-std"))]
212208
pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool {
213209
self.cs_cache.read().unwrap().contains_key(&name)
214210
}
215211

212+
#[cfg(feature = "no-std")]
213+
pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool {
214+
self.cs_cache.read().contains_key(&name)
215+
}
216+
216217
pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) {
217218
let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor {
218219
label: None,
219220
source: wgpu::ShaderSource::Wgsl(source.into()),
220221
}));
222+
#[cfg(not(feature = "no-std"))]
221223
self.cs_cache.write().unwrap().insert(name, module);
224+
#[cfg(feature = "no-std")]
225+
self.cs_cache.write().insert(name, module);
222226
}
223227

228+
#[cfg(not(feature = "no-std"))]
224229
pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
225230
self.cs_cache.read().unwrap().get(&name).cloned()
226231
}
227232

233+
#[cfg(feature = "no-std")]
234+
pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
235+
self.cs_cache.read().get(&name).cloned()
236+
}
237+
228238
// #[allow(unused)]
229239
// pub(crate) unsafe fn get_workspace<E>(&self, len: usize) -> Result<MutexGuard<Buffer>, Error> {
230240
// let num_bytes_required = len * std::mem::size_of::<E>();
@@ -369,7 +379,7 @@ impl<E: Unit> Storage<E> for Webgpu {
369379
type Vec = CachableBuffer<E>;
370380

371381
fn try_alloc_len(&self, len: usize) -> Result<Self::Vec, Error> {
372-
let data = unsafe { self.alloc_empty::<E>(len) }?;
382+
let data = self.alloc_empty::<E>(len)?;
373383
Ok(CachableBuffer {
374384
dev: self.dev.clone(),
375385
queue: self.queue.clone(),

0 commit comments

Comments
 (0)