Skip to content

Commit 06a98b9

Browse files
committed
Added custom modules
1 parent fe89cb6 commit 06a98b9

File tree

4 files changed

+46
-41
lines changed

4 files changed

+46
-41
lines changed

rhtorch/models/modules.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,9 @@
22
import torch
33
import torch.nn as nn
44
import pytorch_lightning as pl
5-
import rhtorch
65
import torchmetrics as tm
76
import math
8-
import pkgutil
9-
import os
10-
import sys
11-
import importlib
12-
13-
14-
def recursive_find_python_class(name, folder=None, current_module="rhtorch.models"):
    """Search a package tree for an attribute (typically a class) by name.

    Imports every plain module directly under *folder* and checks it for
    an attribute called *name*; only if nothing matches does it descend,
    depth-first, into each subpackage.

    Args:
        name: Attribute name to look for.
        folder: List of directory paths to scan. Defaults to the
            ``models`` directory of the installed ``rhtorch`` package.
        current_module: Dotted import path matching *folder*, used to
            build importable module names.

    Returns:
        The attribute that was found.

    Raises:
        SystemExit: If *name* is not found anywhere in the tree.
    """
    if folder is None:
        # Fall back to rhtorch's own models directory.
        folder = [os.path.join(rhtorch.__path__[0], 'models')]

    entries = list(pkgutil.iter_modules(folder))
    found = None

    # First pass: plain modules at this level take precedence.
    for _, module_name, is_package in entries:
        if is_package:
            continue
        candidate = importlib.import_module(f"{current_module}.{module_name}")
        if hasattr(candidate, name):
            found = getattr(candidate, name)
            break

    # Second pass: recurse into subpackages only if nothing matched above.
    if found is None:
        for _, module_name, is_package in entries:
            if not is_package:
                continue
            found = recursive_find_python_class(
                name,
                folder=[os.path.join(folder[0], module_name)],
                current_module=f"{current_module}.{module_name}")
            if found is not None:
                break

    if found is None:
        sys.exit(f"Could not find module {name}")

    return found
42-
7+
from rhtorch.utilities.modules import recursive_find_python_class
438

449
class LightningAE(pl.LightningModule):
4510
def __init__(self, hparams, in_shape=(2, 128, 128, 128)):

rhtorch/torch_training.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from rhtorch.models import modules
1414
from rhtorch.callbacks import plotting
1515
from rhtorch.config_utils import UserConfig
16-
16+
from rhtorch.utilities.modules import recursive_find_python_class
1717

1818
def main():
1919
import argparse
@@ -65,8 +65,8 @@ def main():
6565

6666
# define lightning module
6767
shape_in = data_train.data_shape_in
68-
module = getattr(modules, configs['module'])
69-
model = module(configs, shape_in)
68+
module = recursive_find_python_class( configs['module'] )
69+
model = module(configs, shape_in) # Should also be changed to custom arguments (**configs)
7070

7171
# transfer learning setup
7272
if 'pretrained generator' in configs and configs['pretrained_generator']:
@@ -79,8 +79,6 @@ def main():
7979
pretrained_model_path, hparams=configs, in_shape=shape_in, strict=False)
8080
elif pretrained_model_path.name.endswith(".pt"):
8181
ckpt = torch.load(pretrained_model_path)
82-
# OBS, the 'state_dict' is not set during save?
83-
# What if we are to save multiple models used later for pretrain? (e.g. a GAN with 3 networks?)
8482
pretrained_model = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
8583
model.load_state_dict(pretrained_model, strict=False)
8684
else:
@@ -89,6 +87,7 @@ def main():
8987
raise FileNotFoundError(
9088
"Model file not found. Check your path in config file.")
9189

90+
# This should be more generic. What if we are to only freeze some of the layers?
9291
if 'freeze_encoder' in configs and configs['freeze_encoder']:
9392
message += "with frozen encoder."
9493
model.encoder.freeze()

rhtorch/utilities/__init__.py

Whitespace-only changes.

rhtorch/utilities/modules.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Tue May 25 19:40:15 2021
4+
5+
@author: clad0003
6+
"""
7+
8+
import rhtorch
9+
import pkgutil
10+
import os
11+
import sys
12+
import importlib
13+
14+
def recursive_find_python_class(name, folder=None, current_module="rhtorch.models"):
    """Recursively search a package tree for a class (or any attribute) by name.

    Imports every plain module directly under *folder* and checks it for
    an attribute called *name*; shallow matches win.  Only if nothing
    matches does it descend, depth-first, into each subpackage.

    Args:
        name: Attribute (typically a class) name to look for.
        folder: List of directory paths to scan.  Defaults to the
            ``models`` directory of the installed ``rhtorch`` package.
        current_module: Dotted import path matching *folder*, used to
            build importable module names.

    Returns:
        The attribute that was found (e.g. a class object).

    Raises:
        SystemExit: If *name* is not found anywhere in the tree.
            NOTE(review): sys.exit in a library is unusual; kept for
            backward compatibility with existing callers.
    """
    # Set default search path to the package's own models directory.
    if folder is None:
        folder = [os.path.join(rhtorch.__path__[0], 'models')]

    tr = None
    # Pass 1: plain modules at this level — prefer shallow matches.
    for importer, modname, ispkg in pkgutil.iter_modules(folder):
        if not ispkg:
            m = importlib.import_module(current_module + '.' + modname)
            if hasattr(m, name):
                tr = getattr(m, name)
                break

    # Pass 2: recurse into subpackages only if nothing matched above.
    if tr is None:
        for importer, modname, ispkg in pkgutil.iter_modules(folder):
            if ispkg:
                next_current_module = current_module + '.' + modname
                # BUG FIX: the original always joined with folder[0], so
                # when several folders were supplied, subpackages found
                # under folder[1:] were resolved against the wrong
                # directory.  Use the finder's own search path when it
                # exposes one (pkgutil's FileFinder does).
                base = getattr(importer, 'path', folder[0])
                tr = recursive_find_python_class(
                    name,
                    folder=[os.path.join(base, modname)],
                    current_module=next_current_module)
                if tr is not None:
                    break

    if tr is None:
        sys.exit(f"Could not find module {name}")

    return tr

0 commit comments

Comments
 (0)