-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer_song.py
80 lines (66 loc) · 2.59 KB
/
infer_song.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import shutil
import sys
import torch
import monai
from argparse import ArgumentParser
import pytorch_lightning as pl
import helpers
from monai.data.utils import pad_list_data_collate
from ds import climain
from pytorch_lightning.loggers import TensorBoardLogger
import time
# Select torch.multiprocessing's 'file_system' sharing strategy for this process.
torch.multiprocessing.set_sharing_strategy('file_system')
# Put this file's own directory on the import search path; presumably required
# so that the `basetrain_song` import that follows can be resolved when the
# script is launched from another working directory — TODO confirm.
sys.path.append(os.path.dirname(__file__))
from basetrain_song import benchmark_unet_2d
def infer():
    """Run the test loop of a trained ``benchmark_unet_2d`` checkpoint.

    Arguments are parsed from the command line using the options registered
    by ``helpers.add_training_args``, ``pl.Trainer.add_argparse_args`` and
    ``benchmark_unet_2d.add_model_specific_args``.

    If ``args.ckpt`` equals the literal string ``'local'``, it is replaced by
    the first ``.ckpt`` file that ``os.walk`` yields under a hard-coded local
    directory. The ``loss`` and ``datasetmode`` entries stored in the
    checkpoint's ``hyper_parameters`` are copied onto ``args.loss`` and
    ``args.infer_mode`` before the model is restored.

    After ``trainer.test`` finishes, a ``saved_images`` directory in the
    current working directory (if present) is renamed to
    ``saved_images_mode{datasetmode}_{loss}``, replacing any existing
    directory of that name. The elapsed wall-clock time is printed last.

    Returns:
        None.

    Raises:
        FileNotFoundError: ``args.ckpt`` is ``'local'`` but no ``.ckpt`` file
            exists under the local search directory.
    """
    parser = ArgumentParser()
    parser = helpers.add_training_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = benchmark_unet_2d.add_model_specific_args(parser)
    args = parser.parse_args()

    logger = TensorBoardLogger(
        save_dir=os.path.join('.', 'lightning_logs', f'mode{args.datasetmode}'),
        name='my_test',
    )

    if args.ckpt == 'local':
        search_root = r"F:\Forschung\multiorganseg\good\onlyfresh"
        modelslist = []
        for root, _dirs, files in os.walk(search_root):
            for file in files:
                if file.endswith('.ckpt'):
                    modelslist.append(os.path.join(root, file))
        # Fail with an actionable message instead of a bare IndexError.
        if not modelslist:
            raise FileNotFoundError(
                f"no .ckpt file found under {search_root!r}"
            )
        args.ckpt = modelslist[0]

    # NOTE(review): dataset_mode is fixed to 5 here, independently of the
    # `datasetmode` stored in the checkpoint that is read below — confirm
    # this is intended.
    infer_ds, _, _ = climain(
        args.data_folder[0],
        Input_worker=4,
        mode='test',
        dataset_mode=5,
        clean=args.clean,
    )
    infer_loader = monai.data.DataLoader(
        infer_ds,
        shuffle=False,
        batch_size=4,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
        collate_fn=pad_list_data_collate,
    )

    # Only the stored hyper-parameters are needed from this load, so map the
    # tensors to CPU (works without CUDA, avoids a duplicate copy on the GPU)
    # and keep a reference to the hyper-parameter dict alone.
    hyper_parameters = torch.load(args.ckpt, map_location='cpu')['hyper_parameters']
    datamode = hyper_parameters['datasetmode']
    loss_method = hyper_parameters['loss']
    args.loss = loss_method
    print(args.loss)
    args.infer_mode = datamode

    model = benchmark_unet_2d.load_from_checkpoint(args.ckpt, hparams=vars(args))

    start_time = time.time()
    trainer = pl.Trainer(gpus=-1, logger=logger, precision=16)
    trainer.test(model, infer_loader)

    if os.path.exists('saved_images'):
        newname = f'saved_images_mode{datamode}_{loss_method}'
        # Drop output left by an earlier run with the same mode/loss so the
        # rename below cannot collide with it.
        if os.path.exists(newname):
            shutil.rmtree(newname)
        os.rename('saved_images', newname)

    print('time:', time.time() - start_time)  # one recorded run: 91.36657166481018
if __name__ == "__main__":
    # Script entry point: seed the random number generators through
    # Lightning's seed_everything, then run inference.
    pl.seed_everything(seed=1234)
    infer()
# modelslist=[]
# for root,dirs,files in os.walk(r"F:\Forschung\multiorganseg\good\onlyfresh"):
# for file in files:
# if file.endswith('.ckpt'):
# modelslist.append(os.path.join(root, file))
# infer(modelslist[0],r'F:\Forschung\multiorganseg\data\train_2D')