Skip to content

Commit 4085407

Browse files
committed
Training scannet, script
1 parent f166209 commit 4085407

File tree

10 files changed

+86
-105
lines changed

10 files changed

+86
-105
lines changed

config.py

+4-5
Original file line number | Diff line number | Diff line change
@@ -107,8 +107,6 @@ def add_argument_group(name):
107107
'--threads', type=int, default=1, help='num threads for train/test dataloader')
108108
data_arg.add_argument('--val_threads', type=int, default=1, help='num threads for val dataloader')
109109
data_arg.add_argument('--ignore_label', type=int, default=255)
110-
data_arg.add_argument('--train_elastic_distortion', type=str2bool, default=True)
111-
data_arg.add_argument('--test_elastic_distortion', type=str2bool, default=False)
112110
data_arg.add_argument('--return_transformation', type=str2bool, default=False)
113111
data_arg.add_argument('--ignore_duplicate_class', type=str2bool, default=False)
114112
data_arg.add_argument('--partial_crop', type=float, default=0.)
@@ -198,15 +196,16 @@ def add_argument_group(name):
198196
data_aug_arg.add_argument('--normalize_color', type=str2bool, default=True)
199197
data_aug_arg.add_argument('--data_aug_scale_min', type=float, default=0.9)
200198
data_aug_arg.add_argument('--data_aug_scale_max', type=float, default=1.1)
199+
data_aug_arg.add_argument(
200+
'--data_aug_hue_max', type=float, default=0.5, help='Hue translation range. [0, 1]')
201+
data_aug_arg.add_argument(
202+
'--data_aug_saturation_max', type=float, default=0.20, help='Saturation translation range, [0, 1]')
201203

202204
# Test
203205
test_arg = add_argument_group('Test')
204206
test_arg.add_argument('--visualize', type=str2bool, default=False)
205207
test_arg.add_argument('--test_temporal_average', type=str2bool, default=False)
206208
test_arg.add_argument('--visualize_path', type=str, default='outputs/visualize')
207-
test_arg.add_argument('--test_rotation', type=int, default=-1)
208-
test_arg.add_argument('--test_rotation_save', type=str2bool, default=False)
209-
test_arg.add_argument('--test_rotation_save_dir', type=str, default='outputs/rotation_fulleval')
210209
test_arg.add_argument('--save_prediction', type=str2bool, default=False)
211210
test_arg.add_argument('--save_pred_dir', type=str, default='outputs/pred')
212211
test_arg.add_argument('--test_phase', type=str, default='test', help='Dataset for test')

lib/dataset.py

+14-36
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ class DictDataset(Dataset, ABC):
7676

7777
def __init__(self,
7878
data_paths,
79+
prevoxel_transform=None,
7980
input_transform=None,
8081
target_transform=None,
8182
cache=False,
@@ -91,6 +92,8 @@ def __init__(self,
9192

9293
self.data_root = data_root
9394
self.data_paths = sorted(data_paths)
95+
96+
self.prevoxel_transform = prevoxel_transform
9497
self.input_transform = input_transform
9598
self.target_transform = target_transform
9699

@@ -141,28 +144,27 @@ class VoxelizationDatasetBase(DictDataset, ABC):
141144

142145
def __init__(self,
143146
data_paths,
147+
prevoxel_transform=None,
144148
input_transform=None,
145149
target_transform=None,
146150
cache=False,
147151
data_root='/',
148-
explicit_rotation=-1,
149152
ignore_mask=255,
150153
return_transformation=False,
151154
**kwargs):
152155
"""
153156
ignore_mask: label value for ignore class. It will not be used as a class in the loss or evaluation.
154-
explicit_rotation: # of discretization of 360 degree. # data would be num_data * explicit_rotation
155157
"""
156158
DictDataset.__init__(
157159
self,
158160
data_paths,
161+
prevoxel_transform=prevoxel_transform,
159162
input_transform=input_transform,
160163
target_transform=target_transform,
161164
cache=cache,
162165
data_root=data_root)
163166

164167
self.ignore_mask = ignore_mask
165-
self.explicit_rotation = explicit_rotation
166168
self.return_transformation = return_transformation
167169

168170
def __getitem__(self, index):
@@ -174,8 +176,6 @@ def load_ply(self, index):
174176

175177
def __len__(self):
176178
num_data = len(self.data_paths)
177-
if self.explicit_rotation > 1:
178-
return num_data * self.explicit_rotation
179179
return num_data
180180

181181

@@ -202,7 +202,6 @@ def __init__(self,
202202
input_transform=None,
203203
target_transform=None,
204204
data_root='/',
205-
explicit_rotation=-1,
206205
ignore_label=255,
207206
return_transformation=False,
208207
augment_data=False,
@@ -214,11 +213,11 @@ def __init__(self,
214213
VoxelizationDatasetBase.__init__(
215214
self,
216215
data_paths,
216+
prevoxel_transform=prevoxel_transform,
217217
input_transform=input_transform,
218218
target_transform=target_transform,
219219
cache=cache,
220220
data_root=data_root,
221-
explicit_rotation=config.test_rotation,
222221
ignore_mask=ignore_label,
223222
return_transformation=return_transformation)
224223

@@ -250,13 +249,6 @@ def convert_mat2cfl(self, mat):
250249
return mat[:, :3], mat[:, 3:-1], mat[:, -1]
251250

252251
def __getitem__(self, index):
253-
if self.explicit_rotation > 1:
254-
rotation_space = np.linspace(-np.pi, np.pi, self.explicit_rotation + 1)
255-
rotation_angle = rotation_space[index % self.explicit_rotation]
256-
index //= self.explicit_rotation
257-
else:
258-
rotation_angle = None
259-
260252
pointcloud, center = self.load_ply(index)
261253

262254
# Downsample the pointcloud with finer voxel size before transformation for memory and speed
@@ -269,19 +261,8 @@ def __getitem__(self, index):
269261
pointcloud = self.prevoxel_transform(pointcloud)
270262

271263
coords, feats, labels = self.convert_mat2cfl(pointcloud)
272-
outs = self.voxelizer.voxelize(
273-
coords,
274-
feats,
275-
labels,
276-
center=center,
277-
rotation_angle=rotation_angle,
278-
return_transformation=self.return_transformation)
279-
280-
if self.return_transformation:
281-
coords, feats, labels, transformation = outs
282-
transformation = np.expand_dims(transformation, 0)
283-
else:
284-
coords, feats, labels = outs
264+
coords, feats, labels, transformation = self.voxelizer.voxelize(
265+
coords, feats, labels, center=center)
285266

286267
# map labels not used for evaluation to ignore_label
287268
if self.input_transform is not None:
@@ -296,12 +277,6 @@ def __getitem__(self, index):
296277
return_args.extend([pointcloud.astype(np.float32), transformation.astype(np.float32)])
297278
return tuple(return_args)
298279

299-
def __len__(self):
300-
num_data = sum(self.numels)
301-
if self.explicit_rotation > 1:
302-
return num_data * self.explicit_rotation
303-
return num_data
304-
305280

306281
class TemporalVoxelizationDataset(VoxelizationDataset):
307282

@@ -313,7 +288,6 @@ def __init__(self,
313288
input_transform=None,
314289
target_transform=None,
315290
data_root='/',
316-
explicit_rotation=-1,
317291
ignore_label=255,
318292
temporal_dilation=1,
319293
temporal_numseq=3,
@@ -322,8 +296,8 @@ def __init__(self,
322296
config=None,
323297
**kwargs):
324298
VoxelizationDataset.__init__(self, data_paths, input_transform, target_transform, data_root,
325-
explicit_rotation, ignore_label, return_transformation,
326-
augment_data, config, **kwargs)
299+
ignore_label, return_transformation, augment_data, config,
300+
**kwargs)
327301
self.temporal_dilation = temporal_dilation
328302
self.temporal_numseq = temporal_numseq
329303
temporal_window = temporal_dilation * (temporal_numseq - 1) + 1
@@ -406,6 +380,10 @@ def __getitem__(self, index):
406380
return_args.extend([pointclouds.astype(np.float32), transformations.astype(np.float32)])
407381
return tuple(return_args)
408382

383+
def __len__(self):
384+
num_data = sum(self.numels)
385+
return num_data
386+
409387

410388
def initialize_data_loader(DatasetClass,
411389
config,

lib/datasets/scannet.py

+1
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,7 @@ def __init__(self,
105105
super().__init__(
106106
data_paths,
107107
data_root=data_root,
108+
prevoxel_transform=prevoxel_transform,
108109
input_transform=input_transform,
109110
target_transform=target_transform,
110111
ignore_label=config.ignore_label,

lib/datasets/stanford.py

+1-3
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
from lib.utils import read_txt, fast_hist, per_class_iu
99
from lib.dataset import VoxelizationDataset, DatasetPhase, str2datasetphase_type
1010
import lib.transforms as t
11-
from lib.datasets.preprocessing.stanford_3d import Stanford3DDatasetConverter
1211

1312

1413
class StanfordVoxelizationDatasetBase:
@@ -69,8 +68,7 @@ def test_pointcloud(self, pred_dir):
6968
ious = []
7069
print('Per class IoU:')
7170
for i, iou in enumerate(per_class_iu(hist) * 100):
72-
unmasked_idx = self.label2masked.tolist().index(i)
73-
result_str = f'\t{Stanford3DDatasetConverter.CLASSES[unmasked_idx]}:\t'
71+
result_str = ''
7472
if hist.sum(1)[i]:
7573
result_str += f'{iou}'
7674
ious.append(iou)

lib/test.py

+2-41
Original file line number | Diff line number | Diff line change
@@ -70,11 +70,6 @@ def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
7070
data_iter = data_loader.__iter__()
7171
max_iter = len(data_loader)
7272
max_iter_unique = max_iter
73-
if config.test_rotation > 1:
74-
if config.test_rotation_save:
75-
logging.info('Saving rotation pointcloud prediction at ' + config.test_rotation_save_dir)
76-
os.makedirs(config.test_rotation_save_dir, exist_ok=True)
77-
max_iter_unique //= config.test_rotation
7873

7974
# Fix batch normalization running mean and std
8075
model.eval()
@@ -106,7 +101,7 @@ def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
106101
iter_timer.tic()
107102

108103
if config.wrapper_type != 'None':
109-
color = input[:, :3].int()
104+
color = input[:, :3].int()
110105
if config.normalize_color:
111106
input[:, :3] = input[:, :3] / 255. - 0.5
112107
sinput = SparseTensor(input, coords).to(device)
@@ -119,45 +114,13 @@ def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
119114
pred = get_prediction(dataset, output, target).int()
120115
iter_time = iter_timer.toc(False)
121116

122-
# Get mapping between input and output space
123-
if np.prod(np.array(model.OUT_PIXEL_DIST)) > 1:
124-
permutation = model.get_permutation(model.OUT_PIXEL_DIST, 1).long()
125-
upsampled_pred = pred[permutation].cpu().numpy()
126-
else:
127-
upsampled_pred = pred.cpu().numpy()
128-
129117
if config.save_prediction or config.test_original_pointcloud:
130-
save_predictions(coords, upsampled_pred, transformation, dataset, config, iteration,
131-
save_pred_dir)
132-
133-
# Visualize prediction
134-
if config.visualize:
135-
# Do not save all predictions in rotation-augmented test.
136-
if config.test_rotation < 1 or iteration % config.test_rotation == 0:
137-
visualize_results(coords, input, target, upsampled_pred, config, iteration)
118+
save_predictions(coords, pred, transformation, dataset, config, iteration, save_pred_dir)
138119

139120
if has_gt:
140-
if config.eval_upsample:
141-
# Upscale the target and predication to the original voxel space
142-
output = output[permutation]
143-
pred = get_prediction(dataset, output, target).int()
144-
145121
if config.evaluate_original_pointcloud:
146122
output, pred, target = permute_pointcloud(coords, pointcloud, transformation,
147123
dataset.label_map, output, pred)
148-
if config.test_rotation > 1:
149-
if iteration % config.test_rotation == 0:
150-
output_rotation = output
151-
else:
152-
output_rotation += output
153-
if iteration % config.test_rotation != config.test_rotation - 1:
154-
continue
155-
iteration //= config.test_rotation
156-
output = output_rotation
157-
pred = get_prediction(dataset, output, target).int()
158-
if config.test_rotation_save:
159-
save_rotation_pred(iteration,
160-
pred.cpu().numpy(), dataset, config.test_rotation_save_dir)
161124

162125
target_np = target.numpy()
163126

@@ -221,8 +184,6 @@ def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
221184
if config.test_original_pointcloud:
222185
logging.info('===> Start testing on original pointcloud space.')
223186
dataset.test_pointcloud(save_pred_dir)
224-
if not config.save_prediction:
225-
shutil.rmtree(save_pred_dir)
226187

227188
logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))
228189

lib/transforms.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -195,9 +195,9 @@ def elastic_distortion(self, pointcloud, granularity, magnitude):
195195
return pointcloud
196196

197197
def __call__(self, pointcloud):
198-
if self.distortion_param is not None:
198+
if self.distortion_params is not None:
199199
if random.random() < 0.95:
200-
for granularity, magnitude in self.distortion_param:
200+
for granularity, magnitude in self.distortion_params:
201201
pointcloud = self.elastic_distortion(pointcloud, granularity, magnitude)
202202
return pointcloud
203203

lib/utils.py

+1-4
Original file line number | Diff line number | Diff line change
@@ -197,10 +197,7 @@ def wrapper(*args, **kwargs):
197197

198198

199199
def get_prediction(dataset, output, target):
200-
if dataset.NEED_PRED_POSTPROCESSING:
201-
return dataset.get_prediction(output, target)
202-
else:
203-
return output.max(1)[1]
200+
return output.max(1)[1]
204201

205202

206203
def count_parameters(model):

main.py

+3-11
Original file line number | Diff line number | Diff line change
@@ -58,14 +58,9 @@ def main():
5858
if not config.return_transformation:
5959
raise ValueError('Pointcloud evaluation requires config.return_transformation=true.')
6060

61-
if config.test_rotation > 1:
62-
if config.is_train:
63-
raise ValueError('Rotation evaluation should not be used for training.')
64-
if not (config.return_transformation and config.evaluate_original_pointcloud):
65-
raise ValueError('Rotation evaluation requires config.evaluate_original_pointcloud=true and '
66-
'config.return_transformation=true.')
67-
if config.test_original_pointcloud:
68-
raise ValueError('Cannot run rotation evaluation and KD-tree evaluation together.')
61+
if (config.return_transformation ^ config.evaluate_original_pointcloud):
62+
raise ValueError('Rotation evaluation requires config.evaluate_original_pointcloud=true and '
63+
'config.return_transformation=true.')
6964

7065
logging.info('===> Initializing dataloader')
7166
if config.is_train:
@@ -75,7 +70,6 @@ def main():
7570
phase=config.train_phase,
7671
threads=config.threads,
7772
augment_data=True,
78-
elastic_distortion=config.train_elastic_distortion,
7973
shuffle=True,
8074
repeat=True,
8175
batch_size=config.batch_size,
@@ -87,7 +81,6 @@ def main():
8781
threads=config.val_threads,
8882
phase=config.val_phase,
8983
augment_data=False,
90-
elastic_distortion=config.test_elastic_distortion,
9184
shuffle=True,
9285
repeat=False,
9386
batch_size=config.val_batch_size,
@@ -105,7 +98,6 @@ def main():
10598
threads=config.threads,
10699
phase=config.test_phase,
107100
augment_data=False,
108-
elastic_distortion=config.test_elastic_distortion,
109101
shuffle=False,
110102
repeat=False,
111103
batch_size=config.test_batch_size,

models/resnet.py

+9-3
Original file line number | Diff line number | Diff line change
@@ -77,9 +77,15 @@ def space_n_time_m(n, m):
7777

7878
def weight_initialization(self):
7979
for m in self.modules():
80-
if isinstance(m, nn.BatchNorm1d):
81-
nn.init.constant_(m.weight, 1)
82-
nn.init.constant_(m.bias, 0)
80+
if isinstance(m, ME.MinkowskiConvolution):
81+
ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')
82+
83+
if isinstance(m, ME.MinkowskiConvolutionTranspose):
84+
ME.utils.kaiming_normal_(m.kernel, mode='fan_in', nonlinearity='relu')
85+
86+
if isinstance(m, ME.MinkowskiBatchNorm):
87+
nn.init.constant_(m.bn.weight, 1)
88+
nn.init.constant_(m.bn.bias, 0)
8389

8490
def _make_layer(self,
8591
block,

0 commit comments

Comments
 (0)