forked from peschenbach/evalXai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hubconf.py
76 lines (68 loc) · 2.74 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
# Packages torch.hub must be able to import before it loads entrypoints from this file.
dependencies = ["torch"]
# model implementations
from worker.training.models import CNN, LLR, MLP, CNN8by8, MLP8by8
# base URL for pretrained models - we should move this somewhere else
pretrained_url_base = 'https://github.com/braindatalab/exact/tree/main/worker/ai_model/'
def load_cnn(pretrained=False, n_dim=64, linear_dim=4, *args, **kwargs):
    """
    Build the CNN model, optionally loading pretrained weights.

    Args:
        pretrained (bool): If True, download the state dict at
            ``pretrained_url_base + 'cnn.pth'`` and load it into the model.
        n_dim (int): Dimensionality parameter for CNN.
        linear_dim (int): Dimensionality for linear layers in CNN.
        *args: Extra positional arguments forwarded to the ``CNN`` constructor.
        **kwargs: Extra keyword arguments forwarded to the ``CNN`` constructor.

    Returns:
        CNN: The constructed model, with pretrained weights loaded if requested.
    """
    # Positional arguments always bind first; writing *args first just makes that explicit.
    model = CNN(*args, n_dim=n_dim, linear_dim=linear_dim, **kwargs)
    if pretrained:
        checkpoint = pretrained_url_base + 'cnn.pth'
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict then copies the values into the model's
        # existing parameters.
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def load_llr(pretrained=False, n_dim=64, *args, **kwargs):
    """
    Build the LLR model, optionally loading pretrained weights.

    Args:
        pretrained (bool): If True, download the state dict at
            ``pretrained_url_base + 'llr.pth'`` and load it into the model.
        n_dim (int): Dimensionality parameter for LLR.
        *args: Extra positional arguments forwarded to the ``LLR`` constructor.
        **kwargs: Extra keyword arguments forwarded to the ``LLR`` constructor.

    Returns:
        LLR: The constructed model, with pretrained weights loaded if requested.
    """
    # Positional arguments always bind first; writing *args first just makes that explicit.
    model = LLR(*args, n_dim=n_dim, **kwargs)
    if pretrained:
        checkpoint = pretrained_url_base + 'llr.pth'
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict then copies the values into the model's
        # existing parameters.
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def load_mlp(pretrained=False, n_dim=64, *args, **kwargs):
    """
    Build the MLP model, optionally loading pretrained weights.

    Args:
        pretrained (bool): If True, download the state dict at
            ``pretrained_url_base + 'mlp.pth'`` and load it into the model.
        n_dim (int): Dimensionality parameter for MLP.
        *args: Extra positional arguments forwarded to the ``MLP`` constructor.
        **kwargs: Extra keyword arguments forwarded to the ``MLP`` constructor.

    Returns:
        MLP: The constructed model, with pretrained weights loaded if requested.
    """
    # Positional arguments always bind first; writing *args first just makes that explicit.
    model = MLP(*args, n_dim=n_dim, **kwargs)
    if pretrained:
        checkpoint = pretrained_url_base + 'mlp.pth'
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict then copies the values into the model's
        # existing parameters.
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def load_cnn8by8(pretrained=False, n_dim=64, linear_dim=4, *args, **kwargs):
    """
    Build the CNN8by8 model, optionally loading pretrained weights.

    Args:
        pretrained (bool): If True, download the state dict at
            ``pretrained_url_base + 'cnn8by8.pth'`` and load it into the model.
        n_dim (int): Dimensionality parameter for CNN8by8.
        linear_dim (int): Dimensionality for linear layers in CNN8by8.
        *args: Extra positional arguments forwarded to the ``CNN8by8`` constructor.
        **kwargs: Extra keyword arguments forwarded to the ``CNN8by8`` constructor.

    Returns:
        CNN8by8: The constructed model, with pretrained weights loaded if requested.
    """
    # Positional arguments always bind first; writing *args first just makes that explicit.
    model = CNN8by8(*args, n_dim=n_dim, linear_dim=linear_dim, **kwargs)
    if pretrained:
        checkpoint = pretrained_url_base + 'cnn8by8.pth'
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict then copies the values into the model's
        # existing parameters.
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def load_mlp8by8(pretrained=False, n_dim=64, *args, **kwargs):
    """
    Build the MLP8by8 model, optionally loading pretrained weights.

    Args:
        pretrained (bool): If True, download the state dict at
            ``pretrained_url_base + 'mlp8by8.pth'`` and load it into the model.
        n_dim (int): Dimensionality parameter for MLP8by8.
        *args: Extra positional arguments forwarded to the ``MLP8by8`` constructor.
        **kwargs: Extra keyword arguments forwarded to the ``MLP8by8`` constructor.

    Returns:
        MLP8by8: The constructed model, with pretrained weights loaded if requested.
    """
    # Positional arguments always bind first; writing *args first just makes that explicit.
    model = MLP8by8(*args, n_dim=n_dim, **kwargs)
    if pretrained:
        checkpoint = pretrained_url_base + 'mlp8by8.pth'
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict then copies the values into the model's
        # existing parameters.
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, map_location='cpu')
        model.load_state_dict(state_dict)
    return model