Hi everyone, I'm new to the Burn framework and I'm trying to create a module where I need to mutate a tensor field during the forward pass. Here's what I'm trying to achieve:

```rust
use burn::module::Module;
use burn::tensor::{backend::Backend, Tensor};

#[derive(Debug, Module)]
pub struct MyModule<B: Backend> {
    tensors: Tensor<B, 1>, // <- I need this to be mutable
    // ... other fields
}
```

The Problem: My forward pass function needs to modify the `tensors` field.

What I've Tried:
Questions:
Any guidance would be greatly appreciated! Thanks in advance for your help.
You can define `forward` with `&mut self` and mutate internal fields directly (e.g. `self.tensors = other_computed_tensor` or `self.tensors = self.tensors.clone().slice_assign(slice, partial_update)`; the `.clone()` is needed because `slice_assign` takes the tensor by value, and you cannot move a field out from behind `&mut self`). But this won't work correctly in multi-threaded or multi-device training, where updates need to be thread-safe. For a general thread-safe approach, Burn has `RunningState`, as used in `BatchNorm` (see `burn/crates/burn-core/src/module/param/running.rs`, lines 48 to 58 at commit 9ff428a).
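The two options in the reply can be sketched without Burn itself. This is a minimal, std-only sketch under stated assumptions: `Vec<f32>` stands in for `Tensor<B, 1>`, and `MyModule`, `slice_assign`, `SharedRunningState`, and the element-wise averaging are hypothetical illustrations, not Burn APIs. Pattern 1 mutates the field through `&mut self`, including the clone needed before a by-value update. Pattern 2 records per-thread updates through `&self` behind an `Arc<Mutex<...>>`, which is only loosely analogous to what `RunningState` is for; check `running.rs` for its real API.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread::{self, ThreadId};

// Hypothetical stand-in for the module; `Vec<f32>` replaces `Tensor<B, 1>`.
#[derive(Debug)]
struct MyModule {
    tensors: Vec<f32>,
}

// Hypothetical by-value helper shaped like a consuming `slice_assign`:
// it takes ownership of `values` and returns the updated vector.
fn slice_assign(mut values: Vec<f32>, start: usize, update: &[f32]) -> Vec<f32> {
    values[start..start + update.len()].copy_from_slice(update);
    values
}

impl MyModule {
    // Pattern 1: `forward` takes `&mut self` and reassigns the field.
    fn forward(&mut self, start: usize, partial_update: &[f32]) -> f32 {
        // `slice_assign(self.tensors, ...)` would not compile: it moves out of
        // a field behind `&mut self`. Cloning first avoids that (here it copies the Vec).
        self.tensors = slice_assign(self.tensors.clone(), start, partial_update);
        self.tensors.iter().sum()
    }
}

// Pattern 2: hypothetical shared state updated through `&self`.
// Each thread stores its latest value under its own ThreadId.
#[derive(Debug, Clone, Default)]
struct SharedRunningState {
    per_thread: Arc<Mutex<HashMap<ThreadId, Vec<f32>>>>,
}

impl SharedRunningState {
    // No `&mut self` needed: the Mutex guards the shared map.
    fn update(&self, value: Vec<f32>) {
        self.per_thread
            .lock()
            .unwrap()
            .insert(thread::current().id(), value);
    }

    // Element-wise mean of all per-thread values (assumes equal lengths).
    // Returns None if no thread has called `update` yet.
    fn synced(&self) -> Option<Vec<f32>> {
        let map = self.per_thread.lock().unwrap();
        let n = map.len();
        let mut iter = map.values();
        let mut acc = iter.next()?.clone();
        for v in iter {
            for (a, b) in acc.iter_mut().zip(v) {
                *a += *b;
            }
        }
        for a in acc.iter_mut() {
            *a /= n as f32;
        }
        Some(acc)
    }
}

fn main() {
    let mut module = MyModule { tensors: vec![0.0; 4] };
    let total = module.forward(1, &[2.0, 3.0]);
    assert_eq!(module.tensors, vec![0.0, 2.0, 3.0, 0.0]);
    assert_eq!(total, 5.0);

    // Two threads each record an update through a cloned handle.
    let state = SharedRunningState::default();
    let handles: Vec<_> = (0..2)
        .map(|i| {
            let state = state.clone();
            thread::spawn(move || state.update(vec![i as f32 * 2.0; 2]))
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
    // Mean of [0.0, 0.0] and [2.0, 2.0].
    assert_eq!(state.synced(), Some(vec![1.0, 1.0]));
}
```

The second pattern gives up the simplicity of `&mut self` in exchange for a handle that can be cloned into several threads, which is the property the reply points to when it recommends `RunningState` for multi-threaded or multi-device training.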