basetrain_song.py
import os
import sys

# Make sibling modules importable when this file is run as a script
sys.path.append(os.path.dirname(__file__))

import logging
from argparse import ArgumentParser

import numpy as np
import torch
import monai
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Model imports
from models.BasicUnet import BasicUnet
from models.Unet_song import UNET
# Loss import
from loss import CELoss
# Data and training-pipeline imports
from data_module.song_dataset import Song_dataset_2d_with_CacheDataloder
import helpers
from base_train_2D import BasetRAIN

# Avoid "too many open files" errors with multi-worker data loading
torch.multiprocessing.set_sharing_strategy('file_system')
# This demo contains the training and test pipeline of a 2D U-Net for organ segmentation.
# All 2D pipelines inherit from the 2D base pipeline class; for the complete implementation
# of the training pipeline, please see base_train_2D.py.
class benchmark_unet_2d(BasetRAIN):
    '''
    CELoss does not need one-hot targets; a single-channel label map is enough.
    DiceLoss does: its target must have the same size as the prediction, and
    according to the Dice formula a sigmoid should be applied to the prediction.
    '''

    def __init__(self, hparams):
        super().__init__(hparams)
        self.model = BasicUnet(in_channels=1, out_channels=4, nfilters=32).cuda()
        # self.model = UNET(in_channels=1, out_channels=4).cuda()
        loss_weights = [0.5, 1.0, 1.0, 1.0]
        if hparams['loss'] == 'CE':
            self.loss = CELoss(weight=loss_weights)
            print("CELoss will be used")
        else:
            self.loss = monai.losses.DiceLoss(to_onehot_y=True)
            # self.loss = monai.losses.FocalLoss(gamma=2, to_onehot_y=True)
            # self.loss = DiceLoss(weight=loss_weights)
            print("DiceLoss will be used")
        self.save_hyperparameters()
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--loss', type=str, default='CE')
        # Caveat: argparse's type=bool treats any non-empty string (even "False") as True
        parser.add_argument('--clean', type=bool, default=False)
        parser.add_argument('--infer_mode', type=str, default='Unknown')
        return parser
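
# Illustrative sketch (not called by the pipeline): the target shapes the two loss
# modes expect, assuming batch=2, 4 classes and 64x64 slices. torch.nn.CrossEntropyLoss
# stands in for the project's CELoss, whose exact call signature is not shown here.
def _loss_shape_demo():
    pred = torch.randn(2, 4, 64, 64)                   # raw logits from the U-Net
    ce_target = torch.randint(0, 4, (2, 64, 64))       # class indices, no one-hot needed
    print(torch.nn.CrossEntropyLoss()(pred, ce_target))
    dice_target = torch.randint(0, 4, (2, 1, 64, 64))  # single channel; DiceLoss one-hots it
    print(monai.losses.DiceLoss(to_onehot_y=True, softmax=True)(pred, dice_target))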
class ContinueTrain(benchmark_unet_2d):
    def __init__(self, hparams):
        super().__init__(hparams)
        # Reload the full pipeline from the last checkpoint and reuse it as the model
        self.model = benchmark_unet_2d.load_from_checkpoint(hparams['lastcheckpoint'])
        # self.model.freeze()
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)
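
# Note: Lightning 1.x also offers a built-in resume path as an alternative to
# ContinueTrain (sketch; assumes --lastcheckpoint is added by helpers.add_training_args):
# trainer = pl.Trainer.from_argparse_args(args, resume_from_checkpoint=args.lastcheckpoint)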
# main function
def cli_main():
    pl.seed_everything(1234)
    # Parse the arguments. All pipelines should use Python's argparse for
    # configuration, so that training is easier to launch on a cluster.
    parser = ArgumentParser()
    parser = helpers.add_training_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = benchmark_unet_2d.add_model_specific_args(parser)
    args = parser.parse_args()
    print('resume:', args.resume)
    # Example resume arguments:
    # --resume False
    # --lastcheckpoint F:\Forschung\pure\lightning_logs\mode6\my_model\version_157\checkpoints\last.ckpt
    # --hpar F:\Forschung\pure\lightning_logs\mode6\my_model\version_157\hparams.yaml
    # Checkpoint callback: keep the two best checkpoints by validation IoU, plus the last one.
    # Note: '{valid_iou}' in filename must match a metric logged by the pipeline.
    ckpt_callback = ModelCheckpoint(
        # monitor='avg_iousummean',
        monitor='valid_sum_iou',
        save_top_k=2,
        mode="max",
        save_last=True,
        filename='{epoch:02d}-{valid_iou:.3f}'
    )
    saved_path = os.path.join('.', 'lightning_logs', f'mode{args.datasetmode}',
                              f'{args.loss}_clean_{args.clean}_resume_{args.resume}')
    logger = TensorBoardLogger(save_dir=saved_path, name='my_model')
    # Create the model and trainer using pytorch_lightning
    if args.resume:
        print("Resume")
        # net = benchmark_unet_2d(hparams=vars(args)).load_from_checkpoint(args.lastcheckpoint)
        net = ContinueTrain(hparams=vars(args))
    else:
        net = benchmark_unet_2d(hparams=vars(args))
    trainer = pl.Trainer.from_argparse_args(args, precision=16, check_val_every_n_epoch=2,
                                            callbacks=[ckpt_callback], logger=logger)
    logging.info(f'Manual logging starts. Model version: {trainer.logger.version}_mode{args.datasetmode}')
    # Configure the data module
    logging.info(f'dataset from {args.data_folder}')
    dm = Song_dataset_2d_with_CacheDataloder(args.data_folder[0],
                                             worker=args.worker,
                                             batch_size=args.batch_size,
                                             mode=args.datasetmode,
                                             clean=args.clean)
    # dm.setup(stage='fit')
    trainer.fit(model=net, datamodule=dm)
    logging.info("This is the end of the training.")
    print('THE END')
    sys.exit()
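
# Example invocation (sketch; paths and values are placeholders, and --data_folder,
# --worker, --batch_size and --datasetmode are assumed to come from helpers.add_training_args):
# python basetrain_song.py --data_folder ./data --worker 4 --batch_size 8 \
#     --datasetmode 6 --loss CE --gpus 1 --max_epochs 100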
if __name__ == "__main__":
    cli_main()

# Inference and evaluation examples (kept for reference):
# model_infer(models=glob.glob('.\\lightning_logs\\version_650051\\**\\*.ckpt', recursive=True),
#             raw_dir='D:\\Data\\ct_data\\visceral_manual_seg\\test',
#             tar_dir=None,
#             batch_size=10)
# Organ-wise analysis
# helpers.MOS_eval(pred_path="D:\\Chang\\MultiOrganSeg\\model_output\\benchmark_unet_2D\\10000081_ct\\10000081_ct_seg.nii.gz",
#                  gt_path="D:\\Data\\ct_data\\test\\10000081\\GroundTruth.nii.gz")
# model_debug()