Skip to content

Commit a66c4da

Browse files
committed
Object filling method speed up, using cupy for GPU boost.
1 parent f5761f4 commit a66c4da

File tree

4 files changed

+119
-76
lines changed

4 files changed

+119
-76
lines changed

create_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@
1010
def_maps=True,
1111
fill=True,
1212
canny=False,
13-
fill_texture=True)
13+
fill_tex=True)
1414

1515
create_dataset_norm_data(dataset_name)

dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from data_preparation.normalization import Normalization
88
from data_preparation.filling import fill_depth_map
99
from metrics import get_canny_mask
10-
from object_filling import fill_depth_map, fill_hr_texture
10+
from object_filling import fill_depth_map, fill_texture
1111

1212

1313
class DepthMapSRDataset(Dataset):
@@ -85,7 +85,7 @@ def __getitem__(self, idx):
8585
return sample[0], sample[1], sample[2], sample[3], sample[4]#, sample[5]
8686

8787

88-
def create_dataset(name, hr_dir, lr_dir, textures_dir, scale_lr=True, fill=True, def_maps=False, canny=False, fill_texture=False):
88+
def create_dataset(name, hr_dir, lr_dir, textures_dir, scale_lr=True, fill=True, def_maps=False, canny=False, fill_tex=False):
8989
print("--- CREATE DATASET: " + name + " ---")
9090
data = dict()
9191
hr_depth_maps = None
@@ -164,10 +164,10 @@ def create_dataset(name, hr_dir, lr_dir, textures_dir, scale_lr=True, fill=True,
164164
idx += 1
165165
data['tx'] = textures
166166

167-
if fill_texture:
167+
if fill_tex:
168168
for i in range(len(data["hr"])):
169169
print("> Augmenting HR texture " + str(i + 1) + "/" + str(len(data["hr"])))
170-
data['tx'][i] = fill_hr_texture(data['tx'][i], def_maps[i])
170+
data['tx'][i] = fill_texture(data['tx'][i], def_maps[i])
171171

172172
if canny:
173173
canny_masks = np.empty((len(data["hr"]), data["hr"][0].shape[0], data["hr"][0].shape[1]), dtype=float)

filling_inference.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from dataset import DepthMapSRDataset
2+
from torch.utils.data import DataLoader
3+
from evaluation.pointcloud import *
4+
import torch
5+
import matplotlib.pyplot as plt
6+
import matplotlib as mpl
7+
from object_filling import fill_depth_map, fill_texture
8+
import cupy as cp
9+
10+
device = "cuda" if torch.cuda.is_available() else "cpu"
11+
12+
dataset_name = 'NEAREST-LED-WARIOR-scale_4-ACTUAL'
13+
norm_file_path = 'dataset/' + dataset_name + '_norm.npy'
14+
15+
assert os.path.isfile(norm_file_path), "Normalization file for dataset '" + dataset_name + "' does not exist"
16+
norm_data = np.load(norm_file_path, allow_pickle=True).tolist()
17+
depth_max = norm_data["depth"][1]
18+
19+
dataset = DepthMapSRDataset(dataset_name, train=False, task='depth_map_sr', norm=False)
20+
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
21+
22+
lr_depth_map, texture, hr_depth_map, def_map, object_mask = next(iter(dataloader))
23+
unfilled_hr_depth_map = (hr_depth_map * def_map).clone()
24+
25+
img = np.array((hr_depth_map * def_map)[0][0].float().numpy())
26+
tex = np.array((texture / torch.max(texture))[0][0].float().numpy())
27+
filled, object_map = fill_depth_map(img, depth_max)
28+
filled_texture = fill_texture(tex, def_map[0][0].float().numpy())
29+
30+
cmap = mpl.cm.get_cmap("winter").copy()
31+
cmap.set_under(color='black')
32+
33+
hr_pcl_unfilled = PointCloud(unfilled_hr_depth_map[0][0].numpy())
34+
hr_pcl_unfilled.create_ply("UNFILLED-PLANE-hr-ptcloud-actual")
35+
36+
hr_pcl = PointCloud(cp.asnumpy(filled))
37+
hr_pcl.create_ply("FILLED-PLANE-hr-ptcloud-actual")
38+
39+
plt.figure(plt.figure('HR Depth map UNFILLED'))
40+
plt.imshow(unfilled_hr_depth_map[0][0], cmap=cmap, vmin=0.0001)
41+
42+
plt.figure(plt.figure('HR Depth map'))
43+
plt.imshow(cp.asnumpy(filled), cmap=cmap, vmin=0.0001)
44+
45+
plt.figure(plt.figure('HR Filled Texture'))
46+
plt.imshow(cp.asnumpy(filled_texture), cmap='gray')
47+
48+
plt.show()

object_filling.py

Lines changed: 66 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,31 @@
55
import skimage.io, skimage.transform
66
import matplotlib.pyplot as plt
77
import matplotlib as mpl
8-
8+
import cupy as cp
99

1010
def get_plane(img):
1111
img_loader_x = skimage.io.imread('scanner/normalized_vectors_x_plus_half.tif')
1212
img_loader_y = skimage.io.imread('scanner/normalized_vectors_y_plus_half.tif')
13-
nv_x = skimage.img_as_float(img_loader_x) - 0.5
14-
nv_y = skimage.img_as_float(img_loader_y) - 0.5
13+
nv_x = cp.asarray(skimage.img_as_float(img_loader_x) - 0.5)
14+
nv_y = cp.asarray(skimage.img_as_float(img_loader_y) - 0.5)
1515

16-
img_mean = np.mean(img[img != 0])
17-
mean_error_map = np.power((img - img_mean), 2)
18-
mean_error_map = np.where(img > 0, mean_error_map, np.infty)
16+
img_mean = cp.mean(img[img != 0])
17+
mean_error_map = cp.power((img - img_mean), 2)
18+
mean_error_map = cp.where(img > 0, mean_error_map, cp.infty)
1919

20-
mean_sorted = np.dstack(np.unravel_index(np.argsort(mean_error_map.ravel()), img.shape))[0]
21-
depth_sorted = np.dstack(np.unravel_index(np.argsort(img.ravel()), img.shape))[0]
20+
mean_sorted = cp.dstack(cp.unravel_index(cp.argsort(mean_error_map.ravel()), img.shape))[0]
21+
depth_sorted = cp.dstack(cp.unravel_index(cp.argsort(img.ravel()), img.shape))[0]
2222
a = mean_sorted[:20000]
2323
b = depth_sorted
2424
acceptable_mean = mean_error_map[a[19999, 0], a[19999, 1]]
2525
acceptable_depth = img[b[440000, 0], b[440000, 1]]
2626

27-
img1 = np.where(img == 0, 1.0, img)
27+
img1 = cp.where(img == 0, 1.0, img)
2828
x = nv_x * img1
2929
y = nv_y * img1 * (-1.0)
3030
z = img1.copy()
3131

32-
mesh_points = np.empty((int(img.shape[0] / 20) * int(img.shape[1] / 20), 2))
32+
mesh_points = cp.empty((int(img.shape[0] / 20) * int(img.shape[1] / 20), 2))
3333
for i in range(int(img.shape[0] / 20)):
3434
for j in range(int(img.shape[1] / 20)):
3535
mesh_points[int(i * img.shape[1] / 20) + j, 0] = i * 20
@@ -40,15 +40,15 @@ def get_plane(img):
4040
for i in range(mesh_points.shape[0]):
4141
if mean_error_map[int(mesh_points[i, 0]), int(mesh_points[i, 1])] < acceptable_mean:
4242
if mean_points is None:
43-
mean_points = np.array([mesh_points[i]])
43+
mean_points = cp.array([mesh_points[i]])
4444
else:
45-
mean_points = np.vstack([mean_points, mesh_points[i]])
45+
mean_points = cp.vstack([mean_points, mesh_points[i]])
4646

4747
if img[int(mesh_points[i, 0]), int(mesh_points[i, 1])] > acceptable_depth:
4848
if depth_points is None:
49-
depth_points = np.array([mesh_points[i]])
49+
depth_points = cp.array([mesh_points[i]])
5050
else:
51-
depth_points = np.vstack([depth_points, mesh_points[i]])
51+
depth_points = cp.vstack([depth_points, mesh_points[i]])
5252

5353
mean_points = mean_points[mean_points[:, 1].argsort()]
5454
depth_points = depth_points[depth_points[:, 1].argsort()]
@@ -68,27 +68,27 @@ def get_plane(img):
6868
plt.show()
6969
'''
7070

71-
mean_X = np.empty((mean_points.shape[0]))
72-
mean_Y = np.empty((mean_points.shape[0]))
73-
mean_Z = np.empty((mean_points.shape[0]))
71+
mean_X = cp.empty((mean_points.shape[0]))
72+
mean_Y = cp.empty((mean_points.shape[0]))
73+
mean_Z = cp.empty((mean_points.shape[0]))
7474

7575
for i in range(mean_points.shape[0]):
7676
mean_X[i] = x[int(mean_points[i, 0]), int(mean_points[i, 1])]
7777
mean_Y[i] = y[int(mean_points[i, 0]), int(mean_points[i, 1])]
7878
mean_Z[i] = z[int(mean_points[i, 0]), int(mean_points[i, 1])]
7979

80-
real_mean_points = np.c_[mean_X, mean_Y, mean_Z]
80+
real_mean_points = cp.c_[mean_X, mean_Y, mean_Z]
8181

82-
depth_X = np.empty((depth_points.shape[0]))
83-
depth_Y = np.empty((depth_points.shape[0]))
84-
depth_Z = np.empty((depth_points.shape[0]))
82+
depth_X = cp.empty((depth_points.shape[0]))
83+
depth_Y = cp.empty((depth_points.shape[0]))
84+
depth_Z = cp.empty((depth_points.shape[0]))
8585

8686
for i in range(depth_points.shape[0]):
8787
depth_X[i] = x[int(depth_points[i, 0]), int(depth_points[i, 1])]
8888
depth_Y[i] = y[int(depth_points[i, 0]), int(depth_points[i, 1])]
8989
depth_Z[i] = z[int(depth_points[i, 0]), int(depth_points[i, 1])]
9090

91-
real_depth_points = np.c_[depth_X, depth_Y, depth_Z]
91+
real_depth_points = cp.c_[depth_X, depth_Y, depth_Z]
9292

9393
A = real_mean_points[0]
9494
B = real_depth_points[0]
@@ -107,61 +107,56 @@ def get_plane(img):
107107
plt.show()
108108
'''
109109

110-
normal = np.cross(v1, v2)
110+
normal = cp.cross(v1, v2)
111111

112-
d = -np.sum(normal * A)
112+
d = -cp.sum(normal * A)
113113
#plane = -(normal[0] * x + normal[1] * y + d) / normal[2]
114114
full_plane = -d / (normal[0] * nv_x + normal[1] * nv_y * (-1.0) + normal[2])
115-
img_dis = abs(normal[0] * x + + normal[1] * y + normal[2] * img + d) / np.sqrt(normal[0]**2 + normal[1]**2 + normal[2]**2)
116-
img_dis = np.where(img == 0, 0, img_dis)
115+
img_dis = abs(normal[0] * x + + normal[1] * y + normal[2] * img + d) / cp.sqrt(normal[0]**2 + normal[1]**2 + normal[2]**2)
116+
img_dis = cp.where(img == 0, 0, img_dis)
117117

118118
return full_plane, img_dis
119119

120120

121121
def fill_depth_map(img, background_value):
122122
mask = np.where(img > 0, 0, 1).astype(np.uint8)
123-
uint_mask = np.where(img > 0, 0, 255).astype(np.uint8)
124-
rgb_mask = np.dstack((uint_mask, uint_mask, uint_mask))
125123

126124
labels, holes = measure.label(mask, background=0, return_num=True)
127-
125+
img = cp.asarray(img)
126+
labels = cp.asarray(labels)
128127
print("Filling " + str(holes) + " holes")
129128
with alive_bar(holes) as bar:
130129
for i in range(1, holes + 1):
131130
label = labels == i
132-
hole_border = segmentation.mark_boundaries(np.zeros(rgb_mask.shape),
133-
label.astype(np.int32),
134-
(1, 0, 0),
135-
None,
136-
'outer')[:, :, 0].astype(float)
137-
138-
#hole_border1 = np.logical_or((np.cumsum(label.astype(np.int32), axis=0) == 1), (np.cumsum(label.astype(np.int32), axis=1) == 1))
139-
#hole_border2 = np.logical_or(np.flipud(np.flipud(label.astype(np.int32).cumsum(axis=1)) == 1), np.fliplr(np.fliplr(label.astype(np.int32).cumsum(axis=0)) == 1))
140-
#hole_border = np.logical_or(hole_border1, hole_border2).astype(float)
141-
142-
border_values = img * hole_border
143-
label_img_border_pixels = np.sum(label.astype(np.int32)[0, :]) + \
144-
np.sum(label.astype(np.int32)[label.shape[0] - 1, :]) + \
145-
np.sum(label.astype(np.int32)[1:(label.shape[1] - 1), 0]) + \
146-
np.sum(label.astype(np.int32)[1:(label.shape[1] - 1), label.shape[1] - 1])
131+
132+
label_img_border_pixels = cp.sum(label.astype(cp.int32)[0, :]) + \
133+
cp.sum(label.astype(cp.int32)[label.shape[0] - 1, :]) + \
134+
cp.sum(label.astype(cp.int32)[1:(label.shape[1] - 1), 0]) + \
135+
cp.sum(label.astype(cp.int32)[1:(label.shape[1] - 1), label.shape[1] - 1])
136+
137+
if label_img_border_pixels == 0:
138+
top = cp.roll(label, 1, axis=0)
139+
bottom = cp.roll(label, -1, axis=0)
140+
right = cp.roll(label, 1, axis=1)
141+
left = cp.roll(label, -1, axis=1)
142+
hole_border = cp.logical_or(cp.logical_or(top, bottom), cp.logical_or(right, left))
143+
border_values = img * hole_border
147144

148145
background = label_img_border_pixels > 0
149146

150147
if not background:
151-
mx = np.ma.masked_array(border_values, mask=border_values == 0)
152-
line_values = mx.max(1)
153-
hole = (labels == i).astype(float) * line_values[:, np.newaxis]
148+
line_values = border_values.max(1)
149+
hole = (labels == i).astype(float) * line_values[:, cp.newaxis]
154150
else:
155-
156151
hole = (labels == i).astype(float) * 0
157152

158153
bar()
159154

160-
img = np.where(hole > 0, hole, img)
155+
img = cp.where(hole > 0, hole, img)
161156

162157
plane, dis = get_plane(img)
163-
object_map = np.where(dis < 10, 0, 1.0)
164-
img = np.where(img == 0, background_value, img)
158+
object_map = cp.where(dis < 10, 0, 1.0)
159+
img = cp.where(img == 0, background_value, img)
165160

166161
'''
167162
cmap = mpl.cm.get_cmap("winter").copy()
@@ -175,39 +170,39 @@ def fill_depth_map(img, background_value):
175170

176171
return img, object_map
177172

178-
def fill_hr_texture(tex, def_map):
173+
def fill_texture(tex, def_map):
179174
mask = np.where(def_map > 0, 0, 1).astype(np.uint8)
180-
uint_mask = np.where(def_map > 0, 0, 255).astype(np.uint8)
181-
rgb_mask = np.dstack((uint_mask, uint_mask, uint_mask))
182175

183176
labels, holes = measure.label(mask, background=0, return_num=True)
184-
177+
tex = cp.asarray(tex)
178+
labels = cp.asarray(labels)
185179
print("Filling " + str(holes) + " holes")
186180
with alive_bar(holes) as bar:
187181
for i in range(1, holes + 1):
188182
label = labels == i
189183

190-
hole_border = segmentation.mark_boundaries(np.zeros(rgb_mask.shape),
191-
label.astype(np.int32),
192-
(1, 0, 0),
193-
None,
194-
'outer')[:, :, 0].astype(float)
195-
border_values = tex * hole_border
196-
label_img_border_pixels = np.sum(label.astype(np.int32)[0, :]) + \
197-
np.sum(label.astype(np.int32)[label.shape[0] - 1, :]) + \
198-
np.sum(label.astype(np.int32)[1:(label.shape[1] - 1), 0]) + \
199-
np.sum(label.astype(np.int32)[1:(label.shape[1] - 1), label.shape[1] - 1])
184+
label_img_border_pixels = cp.sum(label.astype(cp.int32)[0, :]) + \
185+
cp.sum(label.astype(cp.int32)[label.shape[0] - 1, :]) + \
186+
cp.sum(label.astype(cp.int32)[1:(label.shape[1] - 1), 0]) + \
187+
cp.sum(label.astype(cp.int32)[1:(label.shape[1] - 1), label.shape[1] - 1])
188+
189+
if label_img_border_pixels == 0:
190+
top = cp.roll(label, 1, axis=0)
191+
bottom = cp.roll(label, -1, axis=0)
192+
right = cp.roll(label, 1, axis=1)
193+
left = cp.roll(label, -1, axis=1)
194+
hole_border = cp.logical_or(cp.logical_or(top, bottom), cp.logical_or(right, left))
195+
border_values = tex * hole_border
200196

201-
background = label_img_border_pixels > 0 or np.sum((labels == i).astype(float)) > 1000
197+
background = label_img_border_pixels > 0
202198

203199
if not background:
204-
mx = np.ma.masked_array(border_values, mask=border_values == 0)
205-
line_values = mx.max(1)
206-
hole = (labels == i).astype(float) * line_values[:, np.newaxis]
200+
line_values = border_values.max(1)
201+
hole = (labels == i).astype(float) * line_values[:, cp.newaxis]
207202
else:
208-
hole = (labels == i).astype(float) * 1
203+
hole = (labels == i).astype(float) * 0
209204

210205
bar()
211206

212-
tex = np.where(hole > 0, hole, tex)
207+
tex = cp.where(hole > 0, hole, tex)
213208
return tex

0 commit comments

Comments
 (0)