Skip to content

Commit

Permalink
Adding distance transform watershed as an instance segmentation method.
Browse files Browse the repository at this point in the history
  • Loading branch information
fgdfgfthgr-fox committed Feb 10, 2025
1 parent 3f2dfef commit a8cdda8
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 12 deletions.
35 changes: 26 additions & 9 deletions Components/DataComponents.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
import time
import h5py
import multiprocessing
# import scipy
import tracemalloc
import pandas as pd
from multiprocessing import Pool, shared_memory, cpu_count
import skimage.morphology as morph
from scipy.ndimage import label
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
from . import Augmentations as Aug
from . import MorphologicalFunctions as Morph

device = "cuda" if torch.cuda.is_available() else "cpu"

Expand Down Expand Up @@ -907,13 +908,13 @@ def predictions_to_final_img(predictions, meta_list, direc, hw_size=128, depth_s
stitched_volumes = stitch_output_volumes(tensor_list, meta_list, hw_size, depth_size, hw_overlap, depth_overlap)
del tensor_list
for volume in stitched_volumes:
array = np.asarray(volume[0])
array = np.where(array >= 0.5, 1, 0)
imageio.v3.imwrite(uri=f'{direc}/{volume[1]}', image=np.uint8(array))
array = volume[0].numpy()
array = np.where(array >= 0.5, 1, 0).astype(np.uint8)
imageio.v3.imwrite(uri=f'{direc}/{volume[1]}', image=array)


def predictions_to_final_img_instance(predictions, meta_list, direc, hw_size=128, depth_size=128, hw_overlap=16,
depth_overlap=16, pixel_reclaim=True):
depth_overlap=16, segmentation_mode='simple', dynamic=10, pixel_reclaim=True):
"""
Stitch the patches of prediction output from network and save it in the selected directory.\n
This one is for instance segmentation.
Expand All @@ -926,6 +927,8 @@ def predictions_to_final_img_instance(predictions, meta_list, direc, hw_size=128
depth_size (int): The depth of the patches. In pixels.
hw_overlap (int): The additional gain in height and width of the patches. In pixels.
depth_overlap (int): The additional gain in depth of the patches. In pixels.
segmentation_mode (str): If 'simple', will identify objects via simple connected component labelling. If 'watershed', will use a distance transform watershed instead, which is slower but yield much less under-segment.
dynamic (int): Dynamic of intensity for the search of regional minima in the distance transform image. Increasing its value will yield more object merges. Default: 10.
pixel_reclaim (bool): Whether to reclaim lost pixel during the instance segmentation, a slow process. Default: True
"""
tensor_list_p = [
Expand Down Expand Up @@ -955,7 +958,7 @@ def predictions_to_final_img_instance(predictions, meta_list, direc, hw_size=128
imageio.v3.imwrite(uri=f'{direc}/Pixels_{semantic[1]}', image=np.float16(semantic[0].numpy()))
imageio.v3.imwrite(uri=f'{direc}/Contour_{contour[1]}', image=np.float16(contour[0].numpy()))
print(f'Computing instance segmentation using contour data for {contour[1]}... Can take a while if the image is big.')
instance_array = instance_segmentation_simple(semantic[0], contour[0], pixel_reclaim=pixel_reclaim)
instance_array = instance_segmentation_simple(semantic[0], contour[0], mode=segmentation_mode, dynamic=dynamic, pixel_reclaim=pixel_reclaim)
imageio.v3.imwrite(uri=f'{direc}/Instance_{contour[1]}', image=instance_array)


Expand Down Expand Up @@ -1042,7 +1045,7 @@ def allocate_pixels_global(batch_indices_and_args):
segmentation_shared[z, y, x] = closest_object.item()


def instance_segmentation_simple(semantic_map, contour_map, size_threshold=10, pixel_reclaim=True, distance_threshold=1, batch_size=2048):
def instance_segmentation_simple(semantic_map, contour_map, size_threshold=10, mode='simple', dynamic=10, pixel_reclaim=True, distance_threshold=1, batch_size=2048):
"""
Using a semantic segmentation map and a contour map to separate touching objects and perform instance segmentation.
Pixels in touching areas are assigned to the closest object based on the largest proportion of pixels within 5 pixel distance to the pixel.
Expand All @@ -1051,6 +1054,8 @@ def instance_segmentation_simple(semantic_map, contour_map, size_threshold=10, p
semantic_map (torch.Tensor): The input semantic segmented map.
contour_map (torch.Tensor): The input contour segmented map.
size_threshold (int): The minimal size in pixel of each object. Object smaller than this will be removed.
mode (str): If 'simple', will identify objects via simple connected component labelling. If 'watershed', will use a distance transform watershed instead, which is slower but yield much less under-segment.
dynamic (int): Dynamic of intensity for the search of regional minima in the distance transform image. Increasing its value will yield more object merges. Default: 10.
pixel_reclaim (bool): Whether to reclaim lost pixel during the instance segmentation, a slow process. Default: True
distance_threshold (int): The radius in pixels to search for nearby pixels when allocating. Default: 1.
batch_size (int): Batch size for pixel reclaim. Default: 2048.
Expand Down Expand Up @@ -1081,8 +1086,20 @@ def instance_segmentation_simple(semantic_map, contour_map, size_threshold=10, p
[0, 1, 0],
[0, 0, 0]]
], dtype=np.byte)
# Connected Component Labelling
label(segmentation, structure=structure, output=segmentation)
if mode == 'simple':
label(segmentation, structure=structure, output=segmentation)
elif mode == 'watershed':
distance_map = Morph.inverter(Morph.chamferdistancetransform3duint16(segmentation))
marker = distance_map + dynamic
hmin = Morph.geodesicreconstructionbyerosion3d(marker, distance_map)
del marker
gc.collect()
hmin = morph.local_minima(hmin).astype(np.uint16)
label(hmin, output=hmin)
print("Starts watershed flooding...")
segmentation = Morph.watershed_3d(distance_map, markers=hmin, mask=segmentation)
del distance_map, hmin
gc.collect()
# Remove small segments
#segmentation = morphology.remove_small_objects(segmentation, min_size=size_threshold, connectivity=structure.numpy())

Expand Down
210 changes: 210 additions & 0 deletions Components/MorphologicalFunctions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import numpy as np
from numba import njit
from heapq import heappush, heappop


def geodesicreconstructionbyerosion3d(marker, mask):
result = np.maximum(marker, mask)
mod_if = True
print("Geodesic Reconstructing...")
while mod_if:
mod_if = False
result, mod_if = _forward_scan_c6(marker, mask, result, mod_if)
result, mod_if = _backward_scan_c6(marker, mask, result, mod_if)
return result


@njit
def _forward_scan_c6(marker, mask, result, mod_if):
for z in range(marker.shape[0]):
for y in range(marker.shape[1]):
for x in range(marker.shape[2]):
current_value = result[z, y, x]
min_value = current_value

if x > 0:
min_value = min(min_value, result[z, y, x-1])
if y > 0:
min_value = min(min_value, result[z, y-1, x])
if z > 0:
min_value = min(min_value, result[z-1, y, x])

min_value = max(min_value, mask[z, y, x])
if min_value < current_value:
result[z, y, x] = min_value
mod_if = True
return result, mod_if


@njit
def _backward_scan_c6(marker, mask, result, mod_if):
for z in range(marker.shape[0] - 1, -1, -1):
for y in range(marker.shape[1] - 1, -1, -1):
for x in range(marker.shape[2] - 1, -1, -1):
current_value = result[z, y, x]
min_value = current_value

if x < marker.shape[2] - 1:
min_value = min(min_value, result[z, y, x+1])
if y < marker.shape[1] - 1:
min_value = min(min_value, result[z, y+1, x])
if z < marker.shape[0] - 1:
min_value = min(min_value, result[z+1, y, x])

min_value = max(min_value, mask[z, y, x])
if min_value < current_value:
result[z, y, x] = min_value
mod_if = True
return result, mod_if


def chamferdistancetransform3duint16(img):
result = np.where(img > 0, np.iinfo(np.uint16).max, 0).astype(np.uint16)
result = _forward_scan_cham_c6(img, result)
result = _backward_scan_cham_c6(img, result)
return result

# Define the Borgefors weights and offsets
offsets = [
(1, 0, 0, 3),
(0, 1, 0, 3),
(0, 0, 1, 3),
(-1, 0, 0, 3),
(0, -1, 0, 3),
(0, 0, -1, 3),
(1, 1, 0, 4),
(1, -1, 0, 4),
(-1, 1, 0, 4),
(-1, -1, 0, 4),
(1, 0, 1, 4),
(1, 0, -1, 4),
(-1, 0, 1, 4),
(-1, 0, -1, 4),
(0, 1, 1, 4),
(0, 1, -1, 4),
(0, -1, 1, 4),
(0, -1, -1, 4),
(1, 1, 1, 5),
(1, 1, -1, 5),
(1, -1, 1, 5),
(1, -1, -1, 5),
(-1, -1, 1, 5),
(-1, 1, 1, 5),
(-1, 1, -1, 5),
(-1, -1, -1, 5),
]

@njit
def _forward_scan_cham_c6(img, result):
for z in range(img.shape[0]):
for y in range(img.shape[1]):
for x in range(img.shape[2]):
if img[z, y, x] == 0:
continue

current_value = result[z, y, x]
new_value = np.iinfo(np.uint16).max

# Iterate over the offsets
for dx, dy, dz, weight in offsets:
x2 = x + dx
y2 = y + dy
z2 = z + dz

# Check if the neighbor is within bounds
if 0 <= x2 < img.shape[2] and 0 <= y2 < img.shape[1] and 0 <= z2 < img.shape[0]:
neighbor_value = result[z2, y2, x2] + weight
new_value = min(new_value, neighbor_value)

# Update the current voxel if a smaller value was found
if new_value < current_value:
result[z, y, x] = new_value
return result


@njit
def _backward_scan_cham_c6(img, result):
for z in range(img.shape[0] - 1, -1, -1):
for y in range(img.shape[1] - 1, -1, -1):
for x in range(img.shape[2] - 1, -1, -1):
if img[z, y, x] == 0:
continue

current_value = result[z, y, x]
new_value = np.iinfo(np.uint16).max

# Iterate over the offsets
for dx, dy, dz, weight in offsets:
x2 = x + dx
y2 = y + dy
z2 = z + dz

# Check if the neighbor is within bounds
if 0 <= x2 < img.shape[2] and 0 <= y2 < img.shape[1] and 0 <= z2 < img.shape[0]:
neighbor_value = result[z2, y2, x2] + weight
new_value = min(new_value, neighbor_value)

# Update the current voxel if a smaller value was found
if new_value < current_value:
result[z, y, x] = new_value
return result


def __heapify_markers_3d(markers, image):
"""Create a priority queue heap with the markers on it for 3D."""
stride = np.array(image.strides, dtype=np.uint32) // image.itemsize
coords = np.argwhere(markers != 0).astype(np.uint32)
ncoords = coords.shape[0]
if ncoords > 0:
pixels = image[markers != 0]
age = np.arange(ncoords, dtype=np.uint32)
offset = np.zeros(coords.shape[0], dtype=np.uint32)
for i in range(image.ndim):
offset = offset + stride[i] * coords[:, i]
pq = [tuple(row) for row in np.column_stack((pixels, age, offset, coords))]
ordering = np.lexsort((age, pixels))
pq = [pq[i] for i in ordering]
else:
pq = np.zeros((0, markers.ndim + 3), int)
return (pq, ncoords)


@njit
def _watershed_loop(pq, labels, connect_increments, mask, image, age):
max_x, max_y, max_z = labels.shape
while len(pq):
pix_value, pix_age, _, pix_x, pix_y, pix_z = heappop(pq)
pix_label = labels[pix_x, pix_y, pix_z]

for dx, dy, dz in connect_increments:
x, y, z = pix_x + dx, pix_y + dy, pix_z + dz
if x < 0 or y < 0 or z < 0 or x >= max_x or y >= max_y or z >= max_z:
continue
if labels[x, y, z]:
continue
if mask is not None and not mask[x, y, z]:
continue

labels[x, y, z] = pix_label
new_pq_item = (np.uint32(image[x, y, z]), np.uint32(age), np.uint32(0), np.uint32(x), np.uint32(y), np.uint32(z))
heappush(pq, new_pq_item)
age += 1
return labels


# The "Slower" watershed taken from scikits-image. Is faster after using Numba.
def watershed_3d(image, markers, mask=None):
"""Watershed algorithm optimized with Numba for 3D images with 6-connectivity."""
connect_increments = [
(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)
]
pq, age = __heapify_markers_3d(markers, image)
print('Watersheding...')
return _watershed_loop(pq, markers, connect_increments, mask, image, age)


def inverter(img):
min = img.min()
max = img.max()
img = max - (img - min)
return img
12 changes: 11 additions & 1 deletion WebUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def start_work_flow(inputs):
f"--hw_size {inputs[hw_size]} --d_size {inputs[d_size]} "
f"--predict_hw_size {inputs[predict_hw_size]} --predict_depth_size {inputs[predict_depth_size]} "
f"--predict_hw_overlap {inputs[predict_hw_overlap]} --predict_depth_overlap {inputs[predict_depth_overlap]} "
f"--watershed_dynamic {inputs[watershed_dynamic]} "
f"--result_folder_path {inputs[result_folder_path]} "
f"--mid_visualization_input {inputs[mid_visualization_input]} "
f"--model_architecture {inputs[model_architecture]} --model_channel_count {inputs[model_channel_count]} "
Expand Down Expand Up @@ -108,6 +109,8 @@ def start_work_flow(inputs):
cmd += "--model_se "
if inputs[find_max_channel_count]:
cmd += "--find_max_channel_count "
if inputs[instance_seg_mode]:
cmd += "--instance_seg_mode "
if inputs[pixel_reclaim]:
cmd += "--pixel_reclaim "

Expand Down Expand Up @@ -640,9 +643,14 @@ def calculate_predict_parameters(model_depth, hw_size, d_size, dk):
info="Horizontal And Vertical flip the image; the augmented images are then passed into the model."
" Corresponding reverse transformation then applys to the output probability maps, and those maps get combined together."
" Can improve segmentation accuracy, but will take longer and consume more CPU memory.")'''
instance_seg_mode = gr.Checkbox(label="Use distance transform watershed for instance segmentation",
info="Use a slower and more memory intensive watershed method for seperate touching objects, "
"instead of simple connected component labelling. "
"Will yield result with much less under-segment objects.")
watershed_dynamic = gr.Number(10, label="Dynamic of intensity for the search of regional minima in the distance transform image. Increasing its value will yield more object merges.")
pixel_reclaim = gr.Checkbox(label="Enable Pixel reclaim operation for instance segmentation",
info="Due to how instance segmentation works, some pixels will be lost when seperating touching objects, "
"this settings will try to reclaim those lost pixels, but can take quite some time.")
"this settings will try to reclaim some of those lost pixels, but can take quite some time.")
#TTA_z = gr.Checkbox(label="Enable Test-Time Augmentation for z dimension", info="Depth Wise flip the image")
with gr.Row():
calculate_repeats = gr.Button(value="Calculate Training Repeats and Epoches!")
Expand Down Expand Up @@ -792,6 +800,8 @@ def show_hide_model_tab(read_existing_model, segmentation_mode):
val_dataset_mode,
test_dataset_mode,
# TTA_xy,
instance_seg_mode,
watershed_dynamic,
pixel_reclaim,
# TTA_z,
}
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ tensorboard==2.15.1
joblib==1.3.2
opencv-python==4.10.0.82
imagecodecs==2024.1.1
overrides==7.7.0
overrides==7.7.0
numba==0.59.1
13 changes: 12 additions & 1 deletion workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,15 @@ def find_max_channel(min_channel, max_channel):
hw_size=args.predict_hw_size, depth_size=args.predict_depth_size,
hw_overlap=args.predict_hw_overlap, depth_overlap=args.predict_depth_overlap)
else:
if args.instance_seg_mode:
mode = 'watershed'
else:
mode = 'simple'
DataComponents.predictions_to_final_img_instance(predictions, meta_info, direc=args.result_folder_path,
hw_size=args.predict_hw_size, depth_size=args.predict_depth_size,
hw_overlap=args.predict_hw_overlap, depth_overlap=args.predict_depth_overlap,
pixel_reclaim=args.pixel_reclaim)
segmentation_mode=mode, dynamic=args.dynamic,
pixel_reclaim=args.pixel_reclaim)
end_time = time.time()
print(f"Converting and saving taken: {end_time - start_time} seconds")

Expand Down Expand Up @@ -284,6 +289,8 @@ def find_max_channel(min_channel, max_channel):
parser.add_argument("--predict_depth_size", type=int, default=128, help="Depth of each Patch (px) during prediction")
parser.add_argument("--predict_hw_overlap", type=int, default=8,
help="Expansion in Height and Width for each Patch (px) during prediction")
parser.add_argument("--watershed_dynamic", type=int, default=10,
help="Dynamic of intensity for the search of regional minima in the distance transform image. Increasing its value will yield more object merges.")
parser.add_argument("--predict_depth_overlap", type=int, default=8, help="Expansion in Depth for each Patch (px) during prediction")
parser.add_argument("--result_folder_path", type=str, default="Datasets/result", help="Result Folder Path")
parser.add_argument("--enable_mid_visualization", action="store_true", help="Enable Visualization")
Expand All @@ -307,6 +314,10 @@ def find_max_channel(min_channel, max_channel):
parser.add_argument("--test_dataset_mode", choices=["Fully Labelled", "Sparsely Labelled"],
default="Fully Labelled", help="Dataset Mode")
#parser.add_argument("--TTA_xy", action="store_true", help="Enable Test-Time Augmentation for xy dimension")
parser.add_argument("--instance_seg_mode", action="store_true",
help="Use a slower and more memory intensive watershed method for seperate touching objects, "
"instead of simple connected component labelling. "
"Will yield result with much less under-segment objects.")
parser.add_argument("--pixel_reclaim", action="store_true", help="Enable reclaim of lost pixel during the instance segmentation.")

args = parser.parse_args()
Expand Down

0 comments on commit a8cdda8

Please sign in to comment.