# https://github.com/meidachen/STPLS3D/blob/main/HAIS/data/prepare_data_inst_instance_stpls3d.py
import glob
import json
import math
import os
import random

import numpy as np
import pandas as pd
import torch


def splitPointCloud(cloud, size=50.0, stride=50):
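    """Tile the cloud in the x-y plane into size x size blocks placed stride apart.

    With the default size == stride the blocks tile the cloud edge to edge
    (points exactly on a boundary fall into both neighbors); a smaller stride
    would produce overlapping blocks.
    """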
    limitMax = np.amax(cloud[:, 0:3], axis=0)
    width = int(np.ceil((limitMax[0] - size) / stride)) + 1
    depth = int(np.ceil((limitMax[1] - size) / stride)) + 1
    cells = [(x * stride, y * stride) for x in range(width) for y in range(depth)]
    blocks = []
    for (x, y) in cells:
        xcond = (cloud[:, 0] <= x + size) & (cloud[:, 0] >= x)
        ycond = (cloud[:, 1] <= y + size) & (cloud[:, 1] >= y)
        cond = xcond & ycond
        block = cloud[cond, :]
        blocks.append(block)
    return blocks


def getFiles(files, fileSplit):
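    """Return the paths in files whose numeric filename prefix is in fileSplit."""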
    res = []
    for filePath in files:
        name = os.path.basename(filePath)
        num = name[:2] if name[:2].isdigit() else name[:1]
        if int(num) in fileSplit:
            res.append(filePath)
    return res


def dataAug(file, semanticKeep):
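    """Rotate the cloud by a random angle around the z-axis and keep only the
    points whose semantic label (column 6) is in semanticKeep."""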
    points = pd.read_csv(file, header=None).values
    angle = random.randint(1, 359)
    angleRadians = math.radians(angle)
    rotationMatrix = np.array([[math.cos(angleRadians), -math.sin(angleRadians), 0],
                               [math.sin(angleRadians), math.cos(angleRadians), 0],
                               [0, 0, 1]])
    points[:, :3] = points[:, :3].dot(rotationMatrix)
    pointsKept = points[np.isin(points[:, 6], semanticKeep)]
    return pointsKept


def preparePthFiles(files, split, outPutFolder, AugTimes=0):
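    """Cut each .txt scene into 50 x 50 blocks and save them as .pth files.

    Each input row is expected to be (x, y, z, r, g, b, semantic, instance).
    For the train and val splits every kept block is saved as a
    (coords, colors, sem_labels, instance_labels) tuple; for the test split
    only (coords, colors) is saved. Each scene is additionally processed
    AugTimes times with random-rotation augmentation.
    """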
    # save the coordinates so that we can merge the data back into a single
    # scene after segmentation, for visualization
    outJsonPath = os.path.join(outPutFolder, 'coordShift.json')
    coordShift = {}
    # increase the z range of a block if it is smaller than this, to work
    # around an issue where spconv may crash during voxelization
    zThreshold = 6

    # map the relevant semantic classes 0-14 to themselves and all other
    # labels to the ignore value -100
    remapper = np.ones(150) * (-100)
    for i, x in enumerate([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]):
        remapper[x] = i
    # map instances to -100 based on their semantic label; the leading -100
    # replaces class 0 (ground), so ground points carry no instance label
    # (swap a class for -100 here to ignore it for instance segmentation)
    remapper_disableInstanceBySemantic = np.ones(150) * (-100)
    for i, x in enumerate([-100, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]):
        remapper_disableInstanceBySemantic[x] = i

    # only augment data for these classes
    semanticKeep = [0, 2, 3, 7, 8, 9, 12, 13]

    counter = 0
    for file in files:

        for AugTime in range(AugTimes + 1):
            if AugTime == 0:
                points = pd.read_csv(file, header=None).values
            else:
                points = dataAug(file, semanticKeep)
            # str.strip('.txt') would remove a *set* of characters, not the
            # extension, so use os.path.splitext to drop it
            name = os.path.splitext(os.path.basename(file))[0] + '_%d' % AugTime

            if split != 'test':
                coordShift['globalShift'] = list(points[:, :3].min(0))
            points[:, :3] = points[:, :3] - points[:, :3].min(0)

            blocks = splitPointCloud(points, size=50, stride=50)
            for blockNum, block in enumerate(blocks):
                if len(block) > 10000:
                    outFilePath = os.path.join(outPutFolder,
                                               name + str(blockNum) + '_inst_nostuff.pth')
                    zRange = block[:, 2].max(0) - block[:, 2].min(0)
                    if zRange < zThreshold:
                        # pad the block with a single ignored point so that
                        # its z range reaches zThreshold
                        paddingPoint = [
                            block[:, 0].mean(0), block[:, 1].mean(0),
                            block[:, 2].max(0) + (zThreshold - zRange),
                            block[:, 3].mean(0), block[:, 4].mean(0),
                            block[:, 5].mean(0), -100, -100
                        ]
                        block = np.append(block, [paddingPoint], axis=0)
                        print('range z is smaller than threshold')
                        print(name + str(blockNum) + '_inst_nostuff')
                    if split != 'test':
                        outFileName = name + str(blockNum) + '_inst_nostuff'
                        coordShift[outFileName] = list(block[:, :3].mean(0))
                    # center coordinates per block; scale colors from [0, 255] to [-1, 1]
                    coords = np.ascontiguousarray(block[:, :3] - block[:, :3].mean(0))
                    colors = np.ascontiguousarray(block[:, 3:6]) / 127.5 - 1

                    coords = np.float32(coords)
                    colors = np.float32(colors)
                    if split != 'test':
                        sem_labels = np.ascontiguousarray(block[:, -2]).astype(np.int32)
                        sem_labels = remapper[sem_labels]

                        instance_labels = np.ascontiguousarray(block[:, -1]).astype(np.float32)

                        # drop instance labels for points whose semantic class
                        # is disabled in remapper_disableInstanceBySemantic
                        disableInstanceBySemantic_labels = np.ascontiguousarray(
                            block[:, -2]).astype(np.int32)
                        disableInstanceBySemantic_labels = remapper_disableInstanceBySemantic[
                            disableInstanceBySemantic_labels]
                        instance_labels = np.where(
                            disableInstanceBySemantic_labels == -100, -100, instance_labels)

                        # remap instance ids to start from 0; [1:] drops the
                        # leading -100 (assumed present) from the unique values
                        uniqueInstances = (np.unique(instance_labels))[1:].astype(np.int32)
                        remapper_instance = np.ones(50000) * (-100)
                        for i, j in enumerate(uniqueInstances):
                            remapper_instance[j] = i

                        instance_labels = remapper_instance[instance_labels.astype(np.int32)]

                        uniqueSemantics = (np.unique(sem_labels))[1:].astype(np.int32)

                        # skip sparse training blocks: too few instances, or
                        # too few instances relative to the semantic classes
                        if split == 'train' and (
                                len(uniqueInstances) < 10
                                or len(uniqueSemantics) >= len(uniqueInstances) - 2):
                            print('unique instance: %d' % len(uniqueInstances))
                            print('unique semantic: %d' % len(uniqueSemantics))
                            print()
                            counter += 1
                        else:
                            torch.save((coords, colors, sem_labels, instance_labels),
                                       outFilePath)
                    else:
                        torch.save((coords, colors), outFilePath)
    print('Total skipped files: %d' % counter)
    with open(outJsonPath, 'w') as f:
        json.dump(coordShift, f)


if __name__ == '__main__':
    data_folder = 'Synthetic_v3_InstanceSegmentation'
    filesOri = sorted(glob.glob(data_folder + '/*.txt'))

    trainSplit = [1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 21, 22, 23, 24]
    trainFiles = getFiles(filesOri, trainSplit)
    split = 'train'
    trainOutDir = split
    os.makedirs(trainOutDir, exist_ok=True)
    preparePthFiles(trainFiles, split, trainOutDir, AugTimes=6)

    valSplit = [5, 10, 15, 20, 25]
    split = 'val'
    valFiles = getFiles(filesOri, valSplit)
    valOutDir = split
    os.makedirs(valOutDir, exist_ok=True)
    preparePthFiles(valFiles, split, valOutDir)