Skip to content

Commit 9d36560

Browse files
authored
Merge pull request #2 from ins-amu/develop
Develop
2 parents 85e51d9 + ec3f859 commit 9d36560

File tree

8 files changed

+178
-32
lines changed

8 files changed

+178
-32
lines changed

.github/workflows/tests.yml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
name: Test
2+
3+
on: [push]
4+
5+
jobs:
6+
test:
7+
runs-on: ubuntu-latest
8+
steps:
9+
- uses: actions/checkout@v2
10+
- name: Set up Python
11+
uses: actions/setup-python@v2
12+
with:
13+
python-version: "3.10"
14+
- name: Install dependencies
15+
run: |
16+
python -m pip install .
17+
18+
- name: Compile C++ code
19+
run: |
20+
cd vbi/models/cpp/_src
21+
make
22+
- name: Run tests
23+
run: |
24+
python -m pytest

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ dependencies = [
3939
"parameterized",
4040
"scikit-learn",
4141
"pycatch22",
42-
"pytest"
42+
"pytest",
43+
"swig"
4344
]
4445
classifiers = [
4546
"Programming Language :: Python :: 3",

vbi/models/cpp/mpr.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import numpy as np
22
from typing import Union
33
from copy import deepcopy
4-
from vbi.models.cpp._src.mpr_sde import MPR_sde as _MPR_sde
5-
from vbi.models.cpp._src.mpr_sde import BoldParams as _BoldParams
4+
5+
try:
6+
from vbi.models.cpp._src.mpr_sde import MPR_sde as _MPR_sde
7+
from vbi.models.cpp._src.mpr_sde import BoldParams as _BoldParams
8+
except ImportError as e:
9+
print(f"Could not import modules: {e}, probably C++ code is not compiled.")
610

711

812
class MPR_sde:

vbi/models/cupy/bold.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import numpy as np
2+
3+
4+
class BoldStephan2008:
5+
6+
def __init__(self, par: dict = {}) -> None:
7+
8+
self._par = self.get_default_parameters()
9+
self.valid_parameters = list(self._par.keys())
10+
self.check_parameters(par)
11+
self._par.update(par)
12+
13+
for key, value in self._par.items():
14+
setattr(self, key, value)
15+
16+
def _prepare(self, nn, ns, xp, n_steps, bold_decimate):
17+
s = xp.zeros((2, nn, ns), dtype=self.dtype)
18+
f = xp.zeros((2, nn, ns), dtype=self.dtype)
19+
ftilde = xp.zeros((2, nn, ns), dtype=self.dtype)
20+
vtilde = xp.zeros((2, nn, ns), dtype=self.dtype)
21+
qtilde = xp.zeros((2, nn, ns), dtype=self.dtype)
22+
v = xp.zeros((2, nn, ns), dtype=self.dtype)
23+
q = xp.zeros((2, nn, ns), dtype=self.dtype)
24+
vv = np.zeros((n_steps // bold_decimate, nn, ns), dtype="f")
25+
qq = np.zeros((n_steps // bold_decimate, nn, ns), dtype="f")
26+
s[0] = 1
27+
f[0] = 1
28+
v[0] = 1
29+
q[0] = 1
30+
ftilde[0] = 0
31+
vtilde[0] = 0
32+
qtilde[0] = 0
33+
34+
return {
35+
"s": s,
36+
"f": f,
37+
"ftilde": ftilde,
38+
"vtilde": vtilde,
39+
"qtilde": qtilde,
40+
"v": v,
41+
"q": q,
42+
"vv": vv,
43+
"qq": qq,
44+
}
45+
46+
def check_parameters(self, par):
47+
for key in par.keys():
48+
if key not in self.valid_parameters:
49+
raise ValueError(f"Invalid parameter {key:s} provided.")
50+
51+
def get_default_parameters(self):
52+
53+
theta0 = 41.0
54+
Eo = 0.42
55+
TE = 0.05
56+
epsilon = 0.36
57+
r0 = 26.0
58+
k1 = 4.3 * theta0 * Eo * TE
59+
k2 = epsilon * r0 * Eo * TE
60+
k3 = 1 - epsilon
61+
62+
par = {
63+
"kappa": 0.7,
64+
"gamma": 0.5,
65+
"tau": 1.0,
66+
"alpha": 0.35,
67+
"epsilon": epsilon,
68+
"Eo": Eo,
69+
"TE": TE,
70+
"vo": 0.09,
71+
"r0": r0,
72+
"theta0": theta0,
73+
"rtol": 1e-6,
74+
"atol": 1e-9,
75+
"k1": k1,
76+
"k2": k2,
77+
"k3": k3,
78+
}
79+
return par
80+
81+
def bold_step(self, r_in, s, f, ftilde, vtilde, qtilde, v, q, dt, P):
82+
83+
kappa, gamma, alpha, tau, Eo = P
84+
ialpha = 1 / alpha
85+
86+
s[1] = s[0] + dt * (r_in - kappa * s[0] - gamma * (f[0] - 1))
87+
f[0] = np.clip(f[0], 1, None)
88+
ftilde[1] = ftilde[0] + dt * (s[0] / f[0])
89+
fv = v[0] ** ialpha # outflow
90+
vtilde[1] = vtilde[0] + dt * ((f[0] - fv) / (tau * v[0]))
91+
q[0] = np.clip(q[0], 0.01, None)
92+
ff = (1 - (1 - Eo) ** (1 / f[0])) / Eo # oxygen extraction
93+
qtilde[1] = qtilde[0] + dt * ((f[0] * ff - fv * q[0] / v[0]) / (tau * q[0]))
94+
95+
f[1] = np.exp(ftilde[1])
96+
v[1] = np.exp(vtilde[1])
97+
q[1] = np.exp(qtilde[1])
98+
99+
f[0] = f[1]
100+
s[0] = s[1]
101+
ftilde[0] = ftilde[1]
102+
vtilde[0] = vtilde[1]
103+
qtilde[0] = qtilde[1]
104+
v[0] = v[1]
105+
q[0] = q[1]
106+
107+
108+
class BoldTVB:
109+
110+
def __init__(self):
111+
pass

vbi/models/cupy/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import numpy as np
2-
from numpy.matlib import repmat
2+
# from numpy.matlib import repmat
33

44
try:
55
import cupy as cp
66
except:
7-
pass
7+
cp = None
88

99

1010
def get_module(engine="gpu"):
@@ -20,7 +20,7 @@ def get_module(engine="gpu"):
2020

2121
def tohost(x):
2222
'''
23-
move data to cpu
23+
move data to cpu if it is on gpu
2424
2525
Parameters
2626
----------
@@ -32,7 +32,9 @@ def tohost(x):
3232
array
3333
data moved to cpu
3434
'''
35-
return cp.asnumpy(x)
35+
if cp is not None and isinstance(x, cp.ndarray):
36+
return cp.asnumpy(x)
37+
return x
3638

3739

3840
def todevice(x):
@@ -79,7 +81,7 @@ def repmat_vec(vec, ns, engine):
7981
repeated vector
8082
8183
'''
82-
vec = repmat(vec, ns, 1).T
84+
vec = np.tile(vec, (ns, 1)).T
8385
vec = move_data(vec, engine)
8486
return vec
8587

vbi/tests/test_module2.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

vbi/tests/test_mpr_cpp.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,43 @@
22
import unittest
33
import numpy as np
44
import networkx as nx
5-
from vbi.models.cpp.mpr import MPR_sde
65

7-
seed = 2
8-
np.random.seed(seed)
9-
torch.manual_seed(seed)
6+
MPR_AVAILABLE = True
7+
try:
8+
from vbi.models.cpp.mpr import MPR_sde
9+
except ImportError:
10+
MPR_AVAILABLE = False
11+
12+
13+
SEED = 2
14+
np.random.seed(SEED)
15+
torch.manual_seed(SEED)
1016

1117
nn = 3
1218
g = nx.complete_graph(nn)
13-
sc = nx.to_numpy_array(g)/ 10.0
19+
sc = nx.to_numpy_array(g) / 10.0
1420

1521

22+
@unittest.skipIf(not MPR_AVAILABLE, "vbi.models.cpp.mpr.MPR_sde module not available")
1623
class testMPRSDE(unittest.TestCase):
17-
24+
1825
mpr = MPR_sde()
1926
p = mpr.get_default_parameters()
20-
p['weights'] = sc
21-
p['seed'] = seed
22-
p['t_cut'] = 0.01 * 60 * 1000
23-
p['t_end'] = 0.02 * 60 * 1000
24-
27+
p["weights"] = sc
28+
p["seed"] = SEED
29+
p["t_cut"] = 0.01 * 60 * 1000
30+
p["t_end"] = 0.02 * 60 * 1000
31+
2532
def test_invalid_parameter_raises_value_error(self):
2633
invalid_params = {"invalid_param": 42}
2734
with self.assertRaises(ValueError):
2835
MPR_sde(par=invalid_params)
2936

3037
def test_run(self):
31-
38+
3239
control = {"G": 0.1, "eta": -4.7}
3340
mpr = MPR_sde(self.p)
3441
sol = mpr.run(par=control)
35-
x = sol["x"]
36-
t = sol["t"]
42+
x = sol["bold_d"]
43+
t = sol["bold_t"]
3744
self.assertEqual(x.shape[0], nn)
38-
39-
40-

vbi/tests/_test_mpr_cupy.py renamed to vbi/tests/test_mpr_cupy.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
import numpy as np
44
import networkx as nx
55
from copy import deepcopy
6-
from vbi.models.cupy.mpr import MPR_sde
6+
7+
MPR_AVAILABLE = True
8+
try:
9+
from vbi.models.cupy.mpr import MPR_sde
10+
except ImportError:
11+
MPR_AVAILABLE = False
712

813

914
seed = 2
@@ -15,6 +20,7 @@
1520
sc = nx.to_numpy_array(g) / 10.0
1621

1722

23+
@unittest.skipIf(not MPR_AVAILABLE, "vbi.models.cupy.mpr.MPR_sde module not available")
1824
class testMPRSDE(unittest.TestCase):
1925

2026
mpr = MPR_sde()

0 commit comments

Comments
 (0)