Skip to content

Commit 8e5962f

Browse files
authored
Merge pull request #335 from stevenabreu7/master
ensure vthr is numpy array, not scalar
2 parents 2cb5bcb + 49a4b35 commit 8e5962f

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

snntorch/export_nir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]:
2727

2828
beta = module.beta.detach().numpy()
2929
vthr = module.threshold.detach().numpy()
30+
vthr = np.array([vthr]) if isinstance(vthr, (int, float)) else vthr
3031
tau_mem = dt / (1 - beta)
3132
r = tau_mem / dt
3233
v_leak = np.zeros_like(beta)
@@ -56,6 +57,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]:
5657
alpha = module.alpha.detach().numpy()
5758
beta = module.beta.detach().numpy()
5859
vthr = module.threshold.detach().numpy()
60+
vthr = np.array([vthr]) if isinstance(vthr, (int, float)) else vthr
5961

6062
tau_syn = dt / (1 - alpha)
6163
tau_mem = dt / (1 - beta)

0 commit comments

Comments
 (0)