Skip to content

Commit c4140c7

Browse files
committed
stanford training
1 parent c6e421a commit c4140c7

23 files changed

+923
-161
lines changed

README.md

+29
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,38 @@ export BATCH_SIZE=N; \
6363
The above script trains a network. You have to change the arguments accordingly. The first argument to the script is the GPU id. The second argument is the log directory postfix; change it to mark your experimental setup. The final argument is a series of miscellaneous arguments. You have to specify the Synthia directory here. Also, you have to wrap all arguments with " ".
6464

6565

66+
## Stanford 3D Dataset
67+
68+
1. Download the Stanford 3D dataset from [the website](http://buildingparser.stanford.edu/dataset.html)
69+
70+
2. Preprocess
71+
72+
Modify the input and output directories accordingly in
73+
74+
`lib/datasets/preprocessing/stanford.py`
75+
76+
And run
77+
78+
```
79+
python -m lib.datasets.preprocessing.stanford
80+
```
81+
82+
3. Train
83+
84+
Modify the Stanford 3D path in the script and run
85+
86+
```
87+
./scripts/train_stanford.sh 0 \
88+
"-default" \
89+
""
90+
```
91+
6692
## Model Zoo
6793

6894
| Model | Dataset | Voxel Size | Conv1 Kernel Size | Performance | Link |
6995
|:-------------:|:-------------------:|:----------:|:-----------------:|:------------------------:|:------:|
7096
| Mink16UNet34C | ScanNet train + val | 2cm | 3 | Test set 73.6% mIoU | [download](https://node1.chrischoy.org/data/publications/minknet/Mink16UNet34C_ScanNet.pth) |
7197
| Mink16UNet34C | ScanNet train | 2cm | 5 | Val 72.219% mIoU without rotation average [per class performance](https://github.com/chrischoy/SpatioTemporalSegmentation/issues/13) | [download](https://node1.chrischoy.org/data/publications/minknet/MinkUNet34C-train-conv1-5.pth) |
98+
| Mink16UNet18 | Stanford Area5 train | 5cm | 5 | Area 5 test 65.483% mIoU w/o rotation average, no sliding window | [download](https://node1.chrischoy.org/data/publications/minknet/Mink16UNet18_stanford-conv1-5.pth) |
99+
100+
Note that sliding-window style evaluation (cropping and stitching results), used in many related works, effectively works as an ensemble (rotation averaging), which boosts the performance.

config.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def add_argument_group(name):
9393

9494
# Data
9595
data_arg = add_argument_group('Data')
96-
data_arg.add_argument('--dataset', type=str, default='ScannetSparseVoxelization2cmDataset')
96+
data_arg.add_argument('--dataset', type=str, default='ScannetVoxelization2cmDataset')
9797
data_arg.add_argument('--temporal_dilation', type=int, default=30)
9898
data_arg.add_argument('--temporal_numseq', type=int, default=3)
9999
data_arg.add_argument('--point_lim', type=int, default=-1)
@@ -134,15 +134,15 @@ def add_argument_group(name):
134134
help='Scannet online voxelization dataset root dir')
135135

136136
data_arg.add_argument(
137-
'--stanford3d_online_path',
137+
'--stanford3d_path',
138138
type=str,
139-
default='/home/chrischoy/datasets/stanford_preprocessed',
139+
default='/home/chrischoy/datasets/Stanford3D',
140140
help='Stanford precropped dataset root dir')
141141

142142
# Training / test parameters
143143
train_arg = add_argument_group('Training')
144144
train_arg.add_argument('--is_train', type=str2bool, default=True)
145-
train_arg.add_argument('--stat_freq', type=int, default=10, help='print frequency')
145+
train_arg.add_argument('--stat_freq', type=int, default=40, help='print frequency')
146146
train_arg.add_argument('--test_stat_freq', type=int, default=100, help='print frequency')
147147
train_arg.add_argument('--save_freq', type=int, default=1000, help='save frequency')
148148
train_arg.add_argument('--val_freq', type=int, default=1000, help='validation frequency')
@@ -174,12 +174,6 @@ def add_argument_group(name):
174174
'--data_aug_color_trans_ratio', type=float, default=0.10, help='Color translation range')
175175
data_aug_arg.add_argument(
176176
'--data_aug_color_jitter_std', type=float, default=0.05, help='STD of color jitter')
177-
data_aug_arg.add_argument(
178-
'--data_aug_height_trans_std', type=float, default=1, help='STD of height translation')
179-
data_aug_arg.add_argument(
180-
'--data_aug_height_jitter_std', type=float, default=0.1, help='STD of height jitter')
181-
data_aug_arg.add_argument(
182-
'--data_aug_normal_jitter_std', type=float, default=0.01, help='STD of normal jitter')
183177
data_aug_arg.add_argument('--normalize_color', type=str2bool, default=True)
184178
data_aug_arg.add_argument('--data_aug_scale_min', type=float, default=0.9)
185179
data_aug_arg.add_argument('--data_aug_scale_max', type=float, default=1.1)

lib/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import open3d as o3d

lib/dataset.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from enum import Enum
88

9+
import torch
910
from torch.utils.data import Dataset, DataLoader
1011

1112
import MinkowskiEngine as ME
@@ -196,6 +197,9 @@ class VoxelizationDataset(VoxelizationDatasetBase):
196197
# MISC.
197198
PREVOXELIZATION_VOXEL_SIZE = None
198199

200+
# Augment coords to feats
201+
AUGMENT_COORDS_TO_FEATS = False
202+
199203
def __init__(self,
200204
data_paths,
201205
prevoxel_transform=None,
@@ -244,24 +248,33 @@ def __init__(self,
244248
self.label_map = label_map
245249
self.NUM_LABELS -= len(self.IGNORE_LABELS)
246250

251+
def _augment_coords_to_feats(self, coords, feats, labels=None):
252+
norm_coords = coords - coords.mean(0)
253+
# color must come first.
254+
if isinstance(coords, np.ndarray):
255+
feats = np.concatenate((feats, norm_coords), 1)
256+
else:
257+
feats = torch.cat((feats, norm_coords), 1)
258+
return coords, feats, labels
259+
247260
def convert_mat2cfl(self, mat):
248261
# Generally, xyz,rgb,label
249262
return mat[:, :3], mat[:, 3:-1], mat[:, -1]
250263

251264
def __getitem__(self, index):
252-
pointcloud, center = self.load_ply(index)
253-
265+
coords, feats, labels, center = self.load_ply(index)
254266
# Downsample the pointcloud with finer voxel size before transformation for memory and speed
255267
if self.PREVOXELIZATION_VOXEL_SIZE is not None:
256268
inds = ME.utils.sparse_quantize(
257-
pointcloud[:, :3] / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True)
258-
pointcloud = pointcloud[inds]
269+
coords / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True)
270+
coords = coords[inds]
271+
feats = feats[inds]
272+
labels = labels[inds]
259273

260274
# Prevoxel transformations
261275
if self.prevoxel_transform is not None:
262-
pointcloud = self.prevoxel_transform(pointcloud)
276+
coords, feats, labels = self.prevoxel_transform(coords, feats, labels)
263277

264-
coords, feats, labels = self.convert_mat2cfl(pointcloud)
265278
coords, feats, labels, transformation = self.voxelizer.voxelize(
266279
coords, feats, labels, center=center)
267280

@@ -273,9 +286,14 @@ def __getitem__(self, index):
273286
if self.IGNORE_LABELS is not None:
274287
labels = np.array([self.label_map[x] for x in labels], dtype=np.int)
275288

289+
# Use coordinate features if config is set
290+
if self.AUGMENT_COORDS_TO_FEATS:
291+
coords, feats, labels = self._augment_coords_to_feats(coords, feats, labels)
292+
276293
return_args = [coords, feats, labels]
277294
if self.return_transformation:
278-
return_args.extend([pointcloud.astype(np.float32), transformation.astype(np.float32)])
295+
return_args.append(transformation.astype(np.float32))
296+
279297
return tuple(return_args)
280298

281299

@@ -319,10 +337,6 @@ def __init__(self,
319337
def load_world_pointcloud(self, filename):
320338
raise NotImplementedError
321339

322-
def convert_mat2cfl(self, mat):
323-
# Generally, xyz,rgb,label
324-
return mat[:, :3], mat[:, 3:-1], mat[:, -1]
325-
326340
def __getitem__(self, index):
327341
for seq_idx, numel in enumerate(self.numels):
328342
if index >= numel:
@@ -353,7 +367,6 @@ def __getitem__(self, index):
353367
# Apply prevoxel transformations
354368
ptcs = [self.prevoxel_transform(ptc) for ptc in ptcs]
355369

356-
ptcs = [self.convert_mat2cfl(ptc) for ptc in ptcs]
357370
coords, feats, labels = zip(*ptcs)
358371
outs = self.voxelizer.voxelize_temporal(
359372
coords, feats, labels, centers=centers, return_transformation=self.return_transformation)

lib/datasets/__init__.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
from .synthia import SynthiaCVPR15cmVoxelizationDataset, SynthiaCVPR30cmVoxelizationDataset, \
2-
SynthiaAllSequencesVoxelizationDataset, SynthiaTemporalVoxelizationDataset
3-
from .stanford import StanfordVoxelizationDataset, StanfordVoxelization2cmDataset
4-
from .scannet import ScannetVoxelizationDataset, ScannetVoxelization2cmDataset
5-
6-
DATASETS = [
7-
StanfordVoxelizationDataset, StanfordVoxelization2cmDataset, ScannetVoxelizationDataset,
8-
ScannetVoxelization2cmDataset, SynthiaCVPR15cmVoxelizationDataset,
9-
SynthiaCVPR30cmVoxelizationDataset, SynthiaTemporalVoxelizationDataset,
10-
SynthiaAllSequencesVoxelizationDataset
11-
]
1+
import lib.datasets.synthia as synthia
2+
import lib.datasets.stanford as stanford
3+
import lib.datasets.scannet as scannet
4+
5+
DATASETS = []
6+
7+
8+
def add_datasets(module):
9+
DATASETS.extend([getattr(module, a) for a in dir(module) if 'Dataset' in a])
10+
11+
12+
add_datasets(stanford)
13+
add_datasets(synthia)
14+
add_datasets(scannet)
1215

1316

1417
def load_dataset(name):

lib/datasets/preprocessing/scannet.py

+31-36
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from pathlib import Path
2-
from random import shuffle
32

43
import numpy as np
5-
import sys
64
from lib.pc_utils import read_plyfile, save_point_cloud
75
from concurrent.futures import ProcessPoolExecutor
86
SCANNET_RAW_PATH = Path('/path/ScanNet_data/')
@@ -19,47 +17,44 @@
1917
print('start preprocess')
2018
# Preprocess data.
2119

20+
2221
def handle_process(path):
23-
f = Path(path.split(',')[0])
24-
phase_out_path = Path(path.split(',')[1])
25-
pointcloud = read_plyfile(f)
26-
# Make sure alpha value is meaningless.
27-
assert np.unique(pointcloud[:, -1]).size == 1
28-
# Load label file.
29-
label_f = f.parent / (f.stem + '.labels' + f.suffix)
30-
if label_f.is_file():
31-
label = read_plyfile(label_f)
32-
# Sanity check that the pointcloud and its label has same vertices.
33-
assert pointcloud.shape[0] == label.shape[0]
34-
assert np.allclose(pointcloud[:, :3], label[:, :3])
35-
else: # Label may not exist in test case.
36-
label = np.zeros_like(pointcloud)
37-
xyz = pointcloud[:, :3]
38-
pool = ProcessPoolExecutor(max_workers=9)
39-
all_points = np.empty((0, 3))
40-
out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix)
41-
processed = np.hstack((pointcloud[:, :6], np.array([label[:, -1]]).T))
42-
save_point_cloud(processed, out_f, with_label=True, verbose=False)
22+
f = Path(path.split(',')[0])
23+
phase_out_path = Path(path.split(',')[1])
24+
pointcloud = read_plyfile(f)
25+
# Make sure alpha value is meaningless.
26+
assert np.unique(pointcloud[:, -1]).size == 1
27+
# Load label file.
28+
label_f = f.parent / (f.stem + '.labels' + f.suffix)
29+
if label_f.is_file():
30+
label = read_plyfile(label_f)
31+
# Sanity check that the pointcloud and its label has same vertices.
32+
assert pointcloud.shape[0] == label.shape[0]
33+
assert np.allclose(pointcloud[:, :3], label[:, :3])
34+
else: # Label may not exist in test case.
35+
label = np.zeros_like(pointcloud)
36+
out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix)
37+
processed = np.hstack((pointcloud[:, :6], np.array([label[:, -1]]).T))
38+
save_point_cloud(processed, out_f, with_label=True, verbose=False)
39+
4340

4441
path_list = []
4542
for out_path, in_path in SUBSETS.items():
46-
phase_out_path = SCANNET_OUT_PATH / out_path
47-
phase_out_path.mkdir(parents=True, exist_ok=True)
48-
for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE):
49-
path_list.append(str(f)+','+str(phase_out_path))
43+
phase_out_path = SCANNET_OUT_PATH / out_path
44+
phase_out_path.mkdir(parents=True, exist_ok=True)
45+
for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE):
46+
path_list.append(str(f) + ',' + str(phase_out_path))
5047

5148
pool = ProcessPoolExecutor(max_workers=20)
52-
result = list(pool.map(handle_process,path_list))
53-
for i in result:
54-
pass
49+
result = list(pool.map(handle_process, path_list))
5550

5651
# Fix bug in the data.
5752
for files, bug_index in BUGS.items():
58-
print(files)
53+
print(files)
5954

60-
for f in SCANNET_OUT_PATH.glob(files):
61-
pointcloud = read_plyfile(f)
62-
bug_mask = pointcloud[:, -1] == bug_index
63-
print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}')
64-
pointcloud[bug_mask, -1] = 0
65-
save_point_cloud(pointcloud, f, with_label=True, verbose=False)
55+
for f in SCANNET_OUT_PATH.glob(files):
56+
pointcloud = read_plyfile(f)
57+
bug_mask = pointcloud[:, -1] == bug_index
58+
print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}')
59+
pointcloud[bug_mask, -1] = 0
60+
save_point_cloud(pointcloud, f, with_label=True, verbose=False)

0 commit comments

Comments
 (0)