import torch
import torch.nn as nn
import random


from ptflops import get_model_complexity_info
from torchsummary import summary
from torch.utils.data import DataLoader


from model.CSPDenseNet import CSPDenseNet
from util.FIATClassificationDataset import FIATClassificationDataset


USE_CUDA = torch.cuda.is_available()  # True if a GPU is available, otherwise False
device = torch.device("cuda" if USE_CUDA else "cpu")  # use the GPU when available, otherwise fall back to the CPU
print("Training on device:", device)


# for reproducibility
random.seed(777)
torch.manual_seed(777)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(777)
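# Note (added): the seeds above fix the RNG state, but fully deterministic CUDA runs typically
# also require torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False;
# those flags are not set in this script.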


## Hyper parameter
training_epochs = 40
batch_size = 5
target_accuracy = 0.99
learning_rate = 0.0003
accuracy_threshold = 0.5
## Hyper parameter
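# Note (added): accuracy_threshold is declared with the other hyperparameters but is not
# referenced anywhere below; training only stops early when avg_acc exceeds target_accuracy.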


model = CSPDenseNet(class_num=4,
                    block_config=(6, 12, 24, 16),
                    expansion_rate=4,  ## Bottleneck expansion size
                    growth_rate=12,
                    activation=torch.nn.SiLU).to(device)


"""
model = DenseNet(class_num=4,
                 num_init_features=12,
                 block_config=(6, 6, 6),
                 expansion_rate=4,  ## Bottleneck expansion size
                 growth_rate=40).to(device)
"""
print('==== model info ====')
summary(model, (3, 224, 224))
print('====================')

macs, params = get_model_complexity_info(model,
                                         (3, 224, 224),
                                         as_strings=True,
                                         print_per_layer_stat=True, verbose=True)
print('{:<30} {:<8}'.format('Computational complexity: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
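# Note (added): ptflops' get_model_complexity_info reports multiply-accumulate operations (MACs)
# and the total parameter count; with as_strings=True both come back as human-readable strings.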


## Save the untrained model (TorchScript)

model.eval()
compiled_model = torch.jit.script(model)
torch.jit.save(compiled_model, "C://Github//DeepLearningStudy//trained_model//FIAT(CSPDenseNet).pt")

trace_input = torch.rand(1, 3, 224, 224).to(device, dtype=torch.float32)
trace_model = torch.jit.trace(model, trace_input)
torch.jit.save(trace_model, "C://Github//DeepLearningStudy//trained_model//FIAT(CSPDenseNet)_Trace.pt")

## Save the untrained model (TorchScript)
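# Note (added): two TorchScript exports of the same, still untrained network are written here.
# torch.jit.script compiles the model with its Python control flow preserved, while
# torch.jit.trace records the ops executed for one fixed 1x3x224x224 input; presumably both
# variants are kept so the architecture can be loaded outside Python (e.g. from libtorch/C++).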


datasets = FIATClassificationDataset('C://Github//DeepLearningStudy//dataset//FIAT_dataset_food//',
                                     label_height=224,
                                     label_width=224,
                                     isColor=True,
                                     isNorm=False)
data_loader = DataLoader(datasets, batch_size=batch_size, shuffle=True)
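# Assumption (added): FIATClassificationDataset is expected to yield 224x224 RGB image tensors
# together with one-hot encoded label vectors of length class_num; the argmax comparison in the
# training loop below only makes sense for that label layout.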


model.train()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
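# Note (added): nn.BCELoss expects probabilities in [0, 1], so the model's final layer is assumed
# to end in a sigmoid; combined with one-hot targets this treats every class as an independent
# binary prediction rather than using a softmax/cross-entropy formulation.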


for epoch in range(training_epochs):  # training_epochs is set to 40 above
    avg_cost = 0
    avg_acc = 0
    total_batch = len(data_loader)

    for X, Y in data_loader:
        gpu_X = X.to(device)
        gpu_Y = Y.to(device)

        model.train()
        optimizer.zero_grad()
        hypothesis = model(gpu_X)
        cost = criterion(hypothesis, gpu_Y)
        cost.backward()
        avg_cost += (cost / total_batch)
        optimizer.step()

        model.eval()
        with torch.no_grad():  # evaluation pass on the same batch; no gradients needed
            prediction = model(gpu_X)
        correct_prediction = torch.argmax(prediction, 1) == torch.argmax(gpu_Y, 1)
        accuracy = correct_prediction.float().mean()
        avg_acc += (accuracy / total_batch)

    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost), 'acc =', '{:.9f}'.format(avg_acc))
    if avg_acc > target_accuracy:
        break

## Save the trained model (TorchScript)
model.eval()
compiled_model = torch.jit.script(model)
torch.jit.save(compiled_model, "C://Github//DeepLearningStudy//trained_model//TRAIN_FIAT(CSPDenseNet).pt")
## Save the trained model (TorchScript)

print('Learning finished')