Skip to content

Commit 36d67fe

Browse files
committed
synthia temporal, network cleanup
1 parent 1fa6da7 commit 36d67fe

10 files changed

+180
-236
lines changed

config.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ def add_argument_group(name):
3030

3131
# Network
3232
net_arg = add_argument_group('Network')
33-
net_arg.add_argument(
34-
'--model', type=str, default='ResUNet14', help='Model name')
33+
net_arg.add_argument('--model', type=str, default='ResUNet14', help='Model name')
3534
net_arg.add_argument(
3635
'--conv1_kernel_size', type=int, default=3, help='First layer conv kernel size')
3736
net_arg.add_argument('--weights', type=str, default='None', help='Saved weights to load')
@@ -117,8 +116,16 @@ def add_argument_group(name):
117116
data_arg.add_argument(
118117
'--synthia_path',
119118
type=str,
120-
default='/home/chrischoy/datasets/synthia_preprocessed',
119+
default='/home/chrischoy/datasets/Synthia/Synthia4D',
121120
help='Point Cloud dataset root dir')
121+
# For temporal sequences
122+
data_arg.add_argument(
123+
'--synthia_camera_path', type=str, default='/home/chrischoy/datasets/Synthia/%s/CameraParams/')
124+
data_arg.add_argument('--synthia_camera_intrinsic_file', type=str, default='intrinsics.txt')
125+
data_arg.add_argument(
126+
'--synthia_camera_extrinsics_file', type=str, default='Stereo_Right/Omni_F/%s.txt')
127+
data_arg.add_argument('--temporal_rand_dilation', type=str2bool, default=False)
128+
data_arg.add_argument('--temporal_rand_numseq', type=str2bool, default=False)
122129

123130
data_arg.add_argument(
124131
'--scannet_path',
@@ -179,7 +186,10 @@ def add_argument_group(name):
179186
data_aug_arg.add_argument(
180187
'--data_aug_hue_max', type=float, default=0.5, help='Hue translation range. [0, 1]')
181188
data_aug_arg.add_argument(
182-
'--data_aug_saturation_max', type=float, default=0.20, help='Saturation translation range, [0, 1]')
189+
'--data_aug_saturation_max',
190+
type=float,
191+
default=0.20,
192+
help='Saturation translation range, [0, 1]')
183193

184194
# Test
185195
test_arg = add_argument_group('Test')

lib/dataset.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ class VoxelizationDataset(VoxelizationDatasetBase):
194194
ELASTIC_DISTORT_PARAMS = None
195195

196196
# MISC.
197-
PREVOXELIZE_VOXEL_SIZE = None
197+
PREVOXELIZATION_VOXEL_SIZE = None
198198

199199
def __init__(self,
200200
data_paths,
@@ -252,9 +252,9 @@ def __getitem__(self, index):
252252
pointcloud, center = self.load_ply(index)
253253

254254
# Downsample the pointcloud with finer voxel size before transformation for memory and speed
255-
if self.PREVOXELIZE_VOXEL_SIZE is not None:
255+
if self.PREVOXELIZATION_VOXEL_SIZE is not None:
256256
inds = ME.utils.sparse_quantize(
257-
pointcloud[:, :3] / self.PREVOXELIZE_VOXEL_SIZE, return_index=True)
257+
pointcloud[:, :3] / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True)
258258
pointcloud = pointcloud[inds]
259259

260260
# Prevoxel transformations
@@ -296,9 +296,18 @@ def __init__(self,
296296
augment_data=False,
297297
config=None,
298298
**kwargs):
299-
VoxelizationDataset.__init__(self, data_paths, input_transform, target_transform, data_root,
300-
ignore_label, return_transformation, augment_data, config,
301-
**kwargs)
299+
VoxelizationDataset.__init__(
300+
self,
301+
data_paths,
302+
prevoxel_transform=prevoxel_transform,
303+
input_transform=input_transform,
304+
target_transform=target_transform,
305+
data_root=data_root,
306+
ignore_label=ignore_label,
307+
return_transformation=return_transformation,
308+
augment_data=augment_data,
309+
config=config,
310+
**kwargs)
302311
self.temporal_dilation = temporal_dilation
303312
self.temporal_numseq = temporal_numseq
304313
temporal_window = temporal_dilation * (temporal_numseq - 1) + 1
@@ -333,10 +342,11 @@ def __getitem__(self, index):
333342
ptcs, centers = zip(*world_pointclouds)
334343

335344
# Downsample pointcloud for speed and memory
336-
if self.PREVOXELIZE_VOXEL_SIZE is not None:
345+
if self.PREVOXELIZATION_VOXEL_SIZE is not None:
337346
new_ptcs = []
338347
for ptc in ptcs:
339-
inds = ME.utils.sparse_quantize(ptc[:, :3] / self.PREVOXELIZE_VOXEL_SIZE, return_index=True)
348+
inds = ME.utils.sparse_quantize(
349+
ptc[:, :3] / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True)
340350
new_ptcs.append(ptc[inds])
341351
ptcs = new_ptcs
342352

lib/datasets/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from .synthia import SynthiaCVPR15cmVoxelizationDataset, SynthiaCVPR30cmVoxelizationDataset, \
2-
SynthiaAllSequencesVoxelizationDataset
2+
SynthiaAllSequencesVoxelizationDataset, SynthiaTemporalVoxelizationDataset
33
from .stanford import StanfordVoxelizationDataset, StanfordVoxelization2cmDataset
44
from .scannet import ScannetVoxelizationDataset, ScannetVoxelization2cmDataset
55

66
DATASETS = [
77
StanfordVoxelizationDataset, StanfordVoxelization2cmDataset, ScannetVoxelizationDataset,
88
ScannetVoxelization2cmDataset, SynthiaCVPR15cmVoxelizationDataset,
9-
SynthiaCVPR30cmVoxelizationDataset, SynthiaAllSequencesVoxelizationDataset
9+
SynthiaCVPR30cmVoxelizationDataset, SynthiaTemporalVoxelizationDataset,
10+
SynthiaAllSequencesVoxelizationDataset
1011
]
1112

1213

lib/datasets/synthia.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class SynthiaVoxelizationDataset(VoxelizationDataset):
110110
TEST_CLIP_BOUND = ((-2500, 2500), (-2500, 2500), (-2500, 2500))
111111
VOXEL_SIZE = 15 # cm
112112

113-
PREVOXELIZE_VOXEL_SIZE = 7.5
113+
PREVOXELIZATION_VOXEL_SIZE = 7.5
114114
# Elastic distortion, (granularity, magnitude) pairs
115115
ELASTIC_DISTORT_PARAMS = ((80, 300),)
116116

@@ -166,7 +166,7 @@ class SynthiaTemporalVoxelizationDataset(TemporalVoxelizationDataset):
166166
TEST_CLIP_BOUND = ((-2500, 2500), (-2500, 2500), (-2500, 2500))
167167
VOXEL_SIZE = 15 # cm
168168

169-
PREVOXELIZE_VOXEL_SIZE = 7.5
169+
PREVOXELIZATION_VOXEL_SIZE = 7.5
170170
# For temporal sequences, the voxel locations have to be aligned exactly.
171171
ELASTIC_DISTORT_PARAMS = None
172172

@@ -179,21 +179,27 @@ class SynthiaTemporalVoxelizationDataset(TemporalVoxelizationDataset):
179179
NUM_LABELS = 16 # Ignore labels are automatically subtracted after processing
180180
IGNORE_LABELS = (0, 1, 13, 14) # void, sky, reserved, reserved
181181

182+
# Split used in the Minkowski ConvNet, CVPR'19
183+
DATA_PATH_FILE = {
184+
DatasetPhase.Train: 'train_cvpr19.txt',
185+
DatasetPhase.Val: 'val_cvpr19.txt',
186+
DatasetPhase.Test: 'test_cvpr19.txt'
187+
}
188+
182189
def __init__(self,
183190
config,
184191
prevoxel_transform=None,
185192
input_transform=None,
186193
target_transform=None,
187194
augment_data=True,
188-
elastic_distortion=False,
189195
cache=False,
190196
phase=DatasetPhase.Train):
191197
if isinstance(phase, str):
192198
phase = str2datasetphase_type(phase)
193199
if phase not in [DatasetPhase.Train, DatasetPhase.TrainVal]:
194200
self.CLIP_BOUND = self.TEST_CLIP_BOUND
195201
data_root = config.synthia_path
196-
data_paths = read_txt(osp.join(data_root, self.DATA_PATH_FILE[phase]))
202+
data_paths = read_txt(osp.join('./splits/synthia4d', self.DATA_PATH_FILE[phase]))
197203
data_paths = sorted([d.split()[0] for d in data_paths])
198204
seq2files = defaultdict(list)
199205
for f in data_paths:
@@ -211,15 +217,15 @@ def __init__(self,
211217
TemporalVoxelizationDataset.__init__(
212218
self,
213219
file_seq_list,
214-
data_root=data_root,
220+
prevoxel_transform=prevoxel_transform,
215221
input_transform=input_transform,
216222
target_transform=target_transform,
223+
data_root=data_root,
217224
ignore_label=config.ignore_label,
218225
temporal_dilation=config.temporal_dilation,
219226
temporal_numseq=config.temporal_numseq,
220227
return_transformation=config.return_transformation,
221228
augment_data=augment_data,
222-
elastic_distortion=elastic_distortion,
223229
config=config)
224230

225231
def load_world_pointcloud(self, filename):

lib/pc_utils.py

-12
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import os
2-
import logging
32
import numpy as np
43
from numpy.linalg import matrix_rank, inv
54
from plyfile import PlyData, PlyElement
65
import pandas as pd
7-
from retrying import retry
86

97
COLOR_MAP_RGB = (
108
(241, 255, 82),
@@ -27,16 +25,6 @@
2725
IGNORE_COLOR = (0, 0, 0)
2826

2927

30-
def retry_on_ioerror(exc):
31-
logging.warning("Retrying file load")
32-
return isinstance(exc, IOError)
33-
34-
35-
@retry(
36-
retry_on_exception=retry_on_ioerror,
37-
wait_exponential_multiplier=1000,
38-
wait_exponential_max=10000,
39-
stop_max_delay=30000)
4028
def read_plyfile(filepath):
4129
"""Read ply file and return it as numpy array. Returns None if empty."""
4230
with open(filepath, 'rb') as f:

lib/voxelizer.py

+57
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,63 @@ def voxelize(self, coords, feats, labels, center=None):
132132

133133
return coords_aug, feats, labels, rigid_transformation.flatten()
134134

135+
def voxelize_temporal(self,
136+
coords_t,
137+
feats_t,
138+
labels_t,
139+
centers=None,
140+
return_transformation=False):
141+
# Legacy code, remove
142+
if centers is None:
143+
centers = [None, ] * len(coords_t)
144+
coords_tc, feats_tc, labels_tc, transformation_tc = [], [], [], []
145+
146+
# ######################### Data Augmentation #############################
147+
# Get rotation and scale
148+
M_v, M_r = self.get_transformation_matrix()
149+
# Apply transformations
150+
rigid_transformation = M_v
151+
if self.use_augmentation:
152+
rigid_transformation = M_r @ rigid_transformation
153+
# ######################### Voxelization #############################
154+
# Voxelize coords
155+
for coords, feats, labels, center in zip(coords_t, feats_t, labels_t, centers):
156+
157+
###################################
158+
# Clip the data if bound exists
159+
if self.clip_bound is not None:
160+
trans_aug_ratio = np.zeros(3)
161+
if self.use_augmentation and self.translation_augmentation_ratio_bound is not None:
162+
for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound):
163+
trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound)
164+
165+
clip_inds = self.clip(coords, center, trans_aug_ratio)
166+
coords, feats = coords[clip_inds], feats[clip_inds]
167+
if labels is not None:
168+
labels = labels[clip_inds]
169+
###################################
170+
171+
homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype)))
172+
coords_aug = np.floor(homo_coords @ rigid_transformation.T)[:, :3]
173+
174+
inds = ME.utils.sparse_quantize(coords_aug, return_index=True)
175+
coords_aug, feats, labels = coords_aug[inds], feats[inds], labels[inds]
176+
177+
# If use normal rotation
178+
if feats.shape[1] > 6:
179+
feats[:, 3:6] = feats[:, 3:6] @ (M_r[:3, :3].T)
180+
181+
coords_tc.append(coords_aug)
182+
feats_tc.append(feats)
183+
labels_tc.append(labels)
184+
transformation_tc.append(rigid_transformation.flatten())
185+
186+
return_args = [coords_tc, feats_tc, labels_tc]
187+
if return_transformation:
188+
return_args.append(transformation_tc)
189+
190+
return tuple(return_args)
191+
135192

136193
def test():
137194
N = 16575

models/res16unet.py

+5-62
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from models.resnet import ResNetBase, get_norm
22
from models.modules.common import ConvType, NormType, conv, conv_tr
3-
from models.modules.resnet_block import BasicBlock, Bottleneck, BasicBlockIN, BottleneckIN, BasicBlockLN
3+
from models.modules.resnet_block import BasicBlock, Bottleneck
44

55
from MinkowskiEngine import MinkowskiReLU
66
import MinkowskiEngine.MinkowskiOps as me
@@ -331,67 +331,6 @@ class Res16UNet34C(Res16UNet34):
331331
PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
332332

333333

334-
# Experimentally, worse than others
335-
class Res16UNetLN14(Res16UNet14):
336-
NORM_TYPE = NormType.SPARSE_LAYER_NORM
337-
BLOCK = BasicBlockLN
338-
339-
340-
class Res16UNetTemporalBase(Res16UNetBase):
341-
"""
342-
Res16UNet that can take 4D independently. No temporal convolution.
343-
"""
344-
CONV_TYPE = ConvType.SPATIAL_HYPERCUBE
345-
346-
def __init__(self, in_channels, out_channels, config, D=4, **kwargs):
347-
super(Res16UNetTemporalBase, self).__init__(in_channels, out_channels, config, D, **kwargs)
348-
349-
350-
class Res16UNetTemporal14(Res16UNet14, Res16UNetTemporalBase):
351-
pass
352-
353-
354-
class Res16UNetTemporal18(Res16UNet18, Res16UNetTemporalBase):
355-
pass
356-
357-
358-
class Res16UNetTemporal34(Res16UNet34, Res16UNetTemporalBase):
359-
pass
360-
361-
362-
class Res16UNetTemporal50(Res16UNet50, Res16UNetTemporalBase):
363-
pass
364-
365-
366-
class Res16UNetTemporal101(Res16UNet101, Res16UNetTemporalBase):
367-
pass
368-
369-
370-
class Res16UNetTemporalIN14(Res16UNetTemporal14):
371-
NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
372-
BLOCK = BasicBlockIN
373-
374-
375-
class Res16UNetTemporalIN18(Res16UNetTemporal18):
376-
NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
377-
BLOCK = BasicBlockIN
378-
379-
380-
class Res16UNetTemporalIN34(Res16UNetTemporal34):
381-
NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
382-
BLOCK = BasicBlockIN
383-
384-
385-
class Res16UNetTemporalIN50(Res16UNetTemporal50):
386-
NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
387-
BLOCK = BottleneckIN
388-
389-
390-
class Res16UNetTemporalIN101(Res16UNetTemporal101):
391-
NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
392-
BLOCK = BottleneckIN
393-
394-
395334
class STRes16UNetBase(Res16UNetBase):
396335

397336
CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS
@@ -404,6 +343,10 @@ class STRes16UNet14(STRes16UNetBase, Res16UNet14):
404343
pass
405344

406345

346+
class STRes16UNet14A(STRes16UNetBase, Res16UNet14A):
347+
pass
348+
349+
407350
class STRes16UNet18(STRes16UNetBase, Res16UNet18):
408351
pass
409352

0 commit comments

Comments
 (0)