@@ -77,9 +77,12 @@ int main(int argc, char *argv[]) {
77
77
78
78
// Finally convert it to a unique pointer dataloader
79
79
auto dataset_mapped = dataset.map (torch::data::transforms::Stack<>());
80
- auto data_loader = torch::data::make_data_loader (std::move (dataset_mapped), torch::data::DataLoaderOptions ().batch_size (1 ).workers (6 ));
80
+ auto sampler = torch::data::samplers::SequentialSampler (dataset.size ().value ());
81
+ auto options = torch::data::DataLoaderOptions ().enforce_ordering (true ).batch_size (1 ).workers (10 );
82
+ auto data_loader = torch::data::make_data_loader (std::move (dataset_mapped), sampler, options);
81
83
82
84
// Loop through our batches of training data
85
+ bool visualize = true ;
83
86
double loss_sum = 0.0 ;
84
87
size_t loss_ct = 0 ;
85
88
size_t batch_idx = 0 ;
@@ -108,66 +111,68 @@ int main(int argc, char *argv[]) {
108
111
std::cout << items_curr << " /" << items_total << " | loss = " << loss.item <float >() << " | loss_avg = " << loss_avg << " (" << loss_ct
109
112
<< " samples)" << std::endl;
110
113
111
- // Softmax the output to get our total class probabilities [N, classes, H, W]
112
- // Thus across all classes, our probabilities should sum to 1
113
- auto output_probs = torch::softmax (output, 1 );
114
-
115
- // Plot the first image, need to change to opencv format [H,W,C]
116
- // Note that we arg max the softmax network output, then need to add an dimension
117
- // We scale up the 0..1 range back to the 0..255 that opencv expects (later cast to int)
118
- torch::Tensor cv_input = 255.0 * batch.data [0 ].permute ({1 , 2 , 0 }).clone ().cpu ();
119
- torch::Tensor cv_label = batch.target [0 ].permute ({1 , 2 , 0 }).clone ().cpu ();
120
- torch::Tensor cv_output = torch::unsqueeze (output_probs[0 ].argmax (0 ), 0 ).permute ({1 , 2 , 0 }).clone ().cpu ();
121
-
122
- // Convert them all to 0..255 ranges
123
- cv_input = cv_input.to (torch::kInt8 );
124
- cv_label = cv_label.to (torch::kInt8 );
125
- cv_output = cv_output.to (torch::kInt8 );
126
-
127
- // Point the cv::Mats to the transformed locations in memory
128
- cv::Mat img_input (cv::Size ((int )cv_input.size (1 ), (int )cv_input.size (0 )), CV_8UC3, cv_input.data_ptr <int8_t >());
129
- cv::Mat img_label (cv::Size ((int )cv_label.size (1 ), (int )cv_label.size (0 )), CV_8UC1, cv_label.data_ptr <int8_t >());
130
- cv::Mat img_output (cv::Size ((int )cv_output.size (1 ), (int )cv_output.size (0 )), CV_8UC1, cv_output.data_ptr <int8_t >());
131
-
132
- // Convert labeled images to color
133
- cv::cvtColor (img_label, img_label, cv::COLOR_GRAY2BGR);
134
- cv::cvtColor (img_output, img_output, cv::COLOR_GRAY2BGR);
135
- // img_label = 255.0 / (double)n_classes * img_label;
136
- // img_output = 255.0 / (double)n_classes * img_output;
137
-
138
- // Change both to be colored like the comma10k
139
- img_label.forEach <cv::Vec3b>([&](cv::Vec3b &px, const int *pos) -> void { px = dataset.map_id2hex [(char )px[0 ]]; });
140
- img_output.forEach <cv::Vec3b>([&](cv::Vec3b &px, const int *pos) -> void { px = dataset.map_id2hex [(char )px[0 ]]; });
141
-
142
- // Finally stack and display in a window
143
- cv::Mat outimg1, outimg2, outimg3;
144
- cv::hconcat (img_input, img_label, outimg1);
145
- cv::hconcat (img_input, img_output, outimg2);
146
- cv::vconcat (outimg1, outimg2, outimg3);
147
- cv::imshow (" prediction" , outimg3);
148
-
149
- // Next we will visualize our probability distributions [N, classes, H, W]
150
- torch::Tensor cv_probs = output_probs[0 ].clone ().cpu ();
151
- cv_probs = cv_probs.to (torch::kFloat32 );
152
- cv::Mat outimg4 = cv::Mat (cv::Size (n_classes * (int )cv_input.size (1 ), (int )cv_input.size (0 )), CV_8UC3, cv::Scalar (0 , 0 , 0 ));
153
- assert ((size_t )output_probs.size (0 ) == 1 );
154
- assert ((size_t )cv_probs.size (0 ) == n_classes);
155
- for (int n = 0 ; n < (int )n_classes; n++) {
156
- cv::Mat imgtmp (cv::Size ((int )cv_probs.size (2 ), (int )cv_probs.size (1 )), CV_32FC1, cv_probs[n].data_ptr <float >());
157
- imgtmp = 255 * imgtmp;
158
- imgtmp.convertTo (imgtmp, CV_8UC1);
159
- cv::Mat imgtmp_color;
160
- cv::applyColorMap (imgtmp, imgtmp_color, cv::COLORMAP_JET);
161
- imgtmp_color.copyTo (outimg4 (cv::Rect (n * (int )cv_input.size (1 ), 0 , imgtmp.cols , imgtmp.rows )));
114
+ // Visualize if we need to
115
+ if (visualize) {
116
+ // Softmax the output to get our total class probabilities [N, classes, H, W]
117
+ // Thus across all classes, our probabilities should sum to 1
118
+ auto output_probs = torch::softmax (output, 1 );
119
+
120
+ // Plot the first image, need to change to opencv format [H,W,C]
121
+ // Note that we arg max the softmax network output, then need to add a dimension
122
+ // We scale up the 0..1 range back to the 0..255 that opencv expects (later cast to int)
123
+ torch::Tensor cv_input = 255.0 * batch.data [0 ].permute ({1 , 2 , 0 }).clone ().cpu ();
124
+ torch::Tensor cv_label = batch.target [0 ].permute ({1 , 2 , 0 }).clone ().cpu ();
125
+ torch::Tensor cv_output = torch::unsqueeze (output_probs[0 ].argmax (0 ), 0 ).permute ({1 , 2 , 0 }).clone ().cpu ();
126
+
127
+ // Convert them all to 0..255 ranges
128
+ cv_input = cv_input.to (torch::kInt8 );
129
+ cv_label = cv_label.to (torch::kInt8 );
130
+ cv_output = cv_output.to (torch::kInt8 );
131
+
132
+ // Point the cv::Mats to the transformed locations in memory
133
+ cv::Mat img_input (cv::Size ((int )cv_input.size (1 ), (int )cv_input.size (0 )), CV_8UC3, cv_input.data_ptr <int8_t >());
134
+ cv::Mat img_label (cv::Size ((int )cv_label.size (1 ), (int )cv_label.size (0 )), CV_8UC1, cv_label.data_ptr <int8_t >());
135
+ cv::Mat img_output (cv::Size ((int )cv_output.size (1 ), (int )cv_output.size (0 )), CV_8UC1, cv_output.data_ptr <int8_t >());
136
+
137
+ // Convert labeled images to color
138
+ cv::cvtColor (img_label, img_label, cv::COLOR_GRAY2BGR);
139
+ cv::cvtColor (img_output, img_output, cv::COLOR_GRAY2BGR);
140
+ // img_label = 255.0 / (double)n_classes * img_label;
141
+ // img_output = 255.0 / (double)n_classes * img_output;
142
+
143
+ // Change both to be colored like the comma10k
144
+ img_label.forEach <cv::Vec3b>([&](cv::Vec3b &px, const int *pos) -> void { px = dataset.map_id2hex [(char )px[0 ]]; });
145
+ img_output.forEach <cv::Vec3b>([&](cv::Vec3b &px, const int *pos) -> void { px = dataset.map_id2hex [(char )px[0 ]]; });
146
+
147
+ // Finally stack and display in a window
148
+ cv::Mat outimg1, outimg2, outimg3;
149
+ cv::hconcat (img_input, img_label, outimg1);
150
+ cv::hconcat (img_input, img_output, outimg2);
151
+ cv::vconcat (outimg1, outimg2, outimg3);
152
+ cv::imshow (" prediction" , outimg3);
153
+
154
+ // Next we will visualize our probability distributions [N, classes, H, W]
155
+ torch::Tensor cv_probs = output_probs[0 ].clone ().cpu ();
156
+ cv_probs = cv_probs.to (torch::kFloat32 );
157
+ cv::Mat outimg4 = cv::Mat (cv::Size (n_classes * (int )cv_input.size (1 ), (int )cv_input.size (0 )), CV_8UC3, cv::Scalar (0 , 0 , 0 ));
158
+ assert ((size_t )output_probs.size (0 ) == 1 );
159
+ assert ((size_t )cv_probs.size (0 ) == n_classes);
160
+ for (int n = 0 ; n < (int )n_classes; n++) {
161
+ cv::Mat imgtmp (cv::Size ((int )cv_probs.size (2 ), (int )cv_probs.size (1 )), CV_32FC1, cv_probs[n].data_ptr <float >());
162
+ imgtmp = 255 * imgtmp;
163
+ imgtmp.convertTo (imgtmp, CV_8UC1);
164
+ cv::Mat imgtmp_color;
165
+ cv::applyColorMap (imgtmp, imgtmp_color, cv::COLORMAP_JET);
166
+ imgtmp_color.copyTo (outimg4 (cv::Rect (n * (int )cv_input.size (1 ), 0 , imgtmp.cols , imgtmp.rows )));
167
+ }
168
+ cv::imshow (" uncertainties" , outimg4);
169
+ cv::waitKey (100 );
170
+
171
+ // Save to file for readme
172
+ // cv::imwrite("/home/patrick/github/segnet/docs/example_pred.png", outimg3);
173
+ // cv::imwrite("/home/patrick/github/segnet/docs/example_probs.png", outimg4);
174
+ // std::exit(EXIT_FAILURE);
162
175
}
163
- cv::imshow (" uncertainties" , outimg4);
164
- cv::waitKey (100 );
165
-
166
- // Save to file for readme
167
- // cv::imwrite("/home/patrick/github/segnet/docs/example_pred.png", outimg3);
168
- // cv::imwrite("/home/patrick/github/segnet/docs/example_probs.png", outimg4);
169
- // std::exit(EXIT_FAILURE);
170
-
171
176
batch_idx++;
172
177
}
173
178
}
0 commit comments