Skip to content

Commit 1c937e8

Browse files
committed
ME collation
1 parent 4a4c0c5 commit 1c937e8

File tree

1 file changed

+7
-14
lines changed

1 file changed

+7
-14
lines changed

lib/transforms.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import scipy.interpolate
88
import torch
99

10+
import MinkowskiEngine as ME
11+
1012

1113
# A sparse tensor consists of coordinates and associated features.
1214
# You must apply augmentation to both.
@@ -260,19 +262,15 @@ def __call__(self, list_data):
260262
f'limit. Truncating batch size at {batch_id} out of {num_full_batch_size} with {batch_num_points - num_points}.'
261263
)
262264
break
263-
coords_batch.append(
264-
torch.cat((torch.from_numpy(
265-
coords[batch_id]).int(), torch.ones(num_points, 1).int() * batch_id), 1))
265+
coords_batch.append(torch.from_numpy(coords[batch_id]).int())
266266
feats_batch.append(torch.from_numpy(feats[batch_id]))
267267
labels_batch.append(torch.from_numpy(labels[batch_id]).int())
268268

269269
batch_id += 1
270270

271271
# Concatenate all lists
272-
coords_batch = torch.cat(coords_batch, 0).int()
273-
feats_batch = torch.cat(feats_batch, 0).float()
274-
labels_batch = torch.cat(labels_batch, 0).int()
275-
return coords_batch, feats_batch, labels_batch
272+
coords_batch, feats_batch, labels_batch = ME.utils.sparse_collate(coords_batch, feats_batch, labels_batch)
273+
return coords_batch, feats_batch.float(), labels_batch
276274

277275

278276
class cflt_collate_fn_factory:
@@ -293,16 +291,11 @@ def __call__(self, list_data):
293291
num_truncated_batch = coords_batch[:, -1].max().item() + 1
294292

295293
batch_id = 0
296-
pointclouds_batch, transformations_batch = [], []
294+
transformations_batch = []
297295
for transformation in transformations:
298296
if batch_id >= num_truncated_batch:
299297
break
300-
transformations_batch.append(
301-
torch.cat(
302-
(torch.from_numpy(transformation), torch.ones(transformation.shape[0], 1) * batch_id),
303-
1))
298+
transformations_batch.append(torch.from_numpy(transformation).float())
304299
batch_id += 1
305300

306-
pointclouds_batch = torch.cat(pointclouds_batch, 0).float()
307-
transformations_batch = torch.cat(transformations_batch, 0).float()
308301
return coords_batch, feats_batch, labels_batch, transformations_batch

0 commit comments

Comments
 (0)