Skip to content

Commit e3de4bc

Browse files
committed
fix #20
1 parent 226e778 commit e3de4bc

File tree

6 files changed

+19
-23
lines changed

6 files changed

+19
-23
lines changed

config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def add_argument_group(name):
103103
data_arg.add_argument('--test_batch_size', type=int, default=1)
104104
data_arg.add_argument('--cache_data', type=str2bool, default=False)
105105
data_arg.add_argument(
106-
'--threads', type=int, default=1, help='num threads for train/test dataloader')
107-
data_arg.add_argument('--val_threads', type=int, default=1, help='num threads for val dataloader')
106+
'--num_workers', type=int, default=1, help='num workers for train/test dataloader')
107+
data_arg.add_argument('--num_val_workers', type=int, default=1, help='num workers for val dataloader')
108108
data_arg.add_argument('--ignore_label', type=int, default=255)
109109
data_arg.add_argument('--return_transformation', type=str2bool, default=False)
110110
data_arg.add_argument('--ignore_duplicate_class', type=str2bool, default=False)

lib/dataset.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import MinkowskiEngine as ME
1313

14-
from lib.pc_utils import read_plyfile
14+
from plyfile import PlyData
1515
import lib.transforms as t
1616
from lib.dataloader import InfSampler
1717
from lib.voxelizer import Voxelizer
@@ -173,7 +173,12 @@ def __getitem__(self, index):
173173

174174
def load_ply(self, index):
175175
filepath = self.data_root / self.data_paths[index]
176-
return read_plyfile(filepath), None
176+
plydata = PlyData.read(filepath)
177+
data = plydata.elements[0].data
178+
coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
179+
feats = np.array([data['red'], data['green'], data['blue']], dtype=np.float32).T
180+
labels = np.array(data['label'], dtype=np.int32)
181+
return coords, feats, labels, None
177182

178183
def __len__(self):
179184
num_data = len(self.data_paths)
@@ -412,7 +417,7 @@ def __len__(self):
412417
def initialize_data_loader(DatasetClass,
413418
config,
414419
phase,
415-
threads,
420+
num_workers,
416421
shuffle,
417422
repeat,
418423
augment_data,
@@ -467,7 +472,7 @@ def initialize_data_loader(DatasetClass,
467472

468473
data_args = {
469474
'dataset': dataset,
470-
'num_workers': threads,
475+
'num_workers': num_workers,
471476
'batch_size': batch_size,
472477
'collate_fn': collate_fn,
473478
}

lib/datasets/stanford.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@
44
import numpy as np
55
from collections import defaultdict
66
from scipy import spatial
7-
from plyfile import PlyData, PlyElement
7+
from plyfile import PlyData
88

9-
import torch
10-
11-
from lib.pc_utils import read_plyfile
129
from lib.utils import read_txt, fast_hist, per_class_iu
1310
from lib.dataset import VoxelizationDataset, DatasetPhase, str2datasetphase_type, cache
1411
import lib.transforms as t

lib/train.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,7 @@ def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
7272
for sub_iter in range(config.iter_size):
7373
# Get training data
7474
data_timer.tic()
75-
if config.return_transformation:
76-
coords, input, target, pointcloud, transformation = data_iter.next()
77-
else:
78-
coords, input, target = data_iter.next()
75+
coords, input, target = data_iter.next()
7976

8077
# For some networks, making the network invariant to even, odd coords is important
8178
coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

lib/transforms.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -287,19 +287,16 @@ def __init__(self, limit_numpoints):
287287
self.limit_numpoints = limit_numpoints
288288

289289
def __call__(self, list_data):
290-
coords, feats, labels, pointclouds, transformations = list(zip(*list_data))
290+
coords, feats, labels, transformations = list(zip(*list_data))
291291
cfl_collate_fn = cfl_collate_fn_factory(limit_numpoints=self.limit_numpoints)
292292
coords_batch, feats_batch, labels_batch = cfl_collate_fn(list(zip(coords, feats, labels)))
293293
num_truncated_batch = coords_batch[:, -1].max().item() + 1
294294

295295
batch_id = 0
296296
pointclouds_batch, transformations_batch = [], []
297-
for pointcloud, transformation in zip(pointclouds, transformations):
297+
for transformation in transformations:
298298
if batch_id >= num_truncated_batch:
299299
break
300-
pointclouds_batch.append(
301-
torch.cat((torch.from_numpy(pointcloud), torch.ones(pointcloud.shape[0], 1) * batch_id),
302-
1))
303300
transformations_batch.append(
304301
torch.cat(
305302
(torch.from_numpy(transformation), torch.ones(transformation.shape[0], 1) * batch_id),
@@ -308,4 +305,4 @@ def __call__(self, list_data):
308305

309306
pointclouds_batch = torch.cat(pointclouds_batch, 0).float()
310307
transformations_batch = torch.cat(transformations_batch, 0).float()
311-
return coords_batch, feats_batch, labels_batch, pointclouds_batch, transformations_batch
308+
return coords_batch, feats_batch, labels_batch, transformations_batch

main.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def main():
7070
DatasetClass,
7171
config,
7272
phase=config.train_phase,
73-
threads=config.threads,
73+
num_workers=config.num_workers,
7474
augment_data=True,
7575
shuffle=True,
7676
repeat=True,
@@ -80,7 +80,7 @@ def main():
8080
val_data_loader = initialize_data_loader(
8181
DatasetClass,
8282
config,
83-
threads=config.val_threads,
83+
num_workers=config.num_val_workers,
8484
phase=config.val_phase,
8585
augment_data=False,
8686
shuffle=True,
@@ -97,7 +97,7 @@ def main():
9797
test_data_loader = initialize_data_loader(
9898
DatasetClass,
9999
config,
100-
threads=config.threads,
100+
num_workers=config.num_workers,
101101
phase=config.test_phase,
102102
augment_data=False,
103103
shuffle=False,

0 commit comments

Comments
 (0)