7
7
import scipy .interpolate
8
8
import torch
9
9
10
+ import MinkowskiEngine as ME
11
+
10
12
11
13
# A sparse tensor consists of coordinates and associated features.
12
14
# You must apply augmentation to both.
@@ -260,19 +262,15 @@ def __call__(self, list_data):
260
262
f'limit. Truncating batch size at { batch_id } out of { num_full_batch_size } with { batch_num_points - num_points } .'
261
263
)
262
264
break
263
- coords_batch .append (
264
- torch .cat ((torch .from_numpy (
265
- coords [batch_id ]).int (), torch .ones (num_points , 1 ).int () * batch_id ), 1 ))
265
+ coords_batch .append (torch .from_numpy (coords [batch_id ]).int ())
266
266
feats_batch .append (torch .from_numpy (feats [batch_id ]))
267
267
labels_batch .append (torch .from_numpy (labels [batch_id ]).int ())
268
268
269
269
batch_id += 1
270
270
271
271
# Concatenate all lists
272
- coords_batch = torch .cat (coords_batch , 0 ).int ()
273
- feats_batch = torch .cat (feats_batch , 0 ).float ()
274
- labels_batch = torch .cat (labels_batch , 0 ).int ()
275
- return coords_batch , feats_batch , labels_batch
272
+ coords_batch , feats_batch , labels_batch = ME .utils .sparse_collate (coords_batch , feats_batch , labels_batch )
273
+ return coords_batch , feats_batch .float (), labels_batch
276
274
277
275
278
276
class cflt_collate_fn_factory :
@@ -293,16 +291,11 @@ def __call__(self, list_data):
293
291
num_truncated_batch = coords_batch [:, - 1 ].max ().item () + 1
294
292
295
293
batch_id = 0
296
- pointclouds_batch , transformations_batch = [], []
294
+ transformations_batch = []
297
295
for transformation in transformations :
298
296
if batch_id >= num_truncated_batch :
299
297
break
300
- transformations_batch .append (
301
- torch .cat (
302
- (torch .from_numpy (transformation ), torch .ones (transformation .shape [0 ], 1 ) * batch_id ),
303
- 1 ))
298
+ transformations_batch .append (torch .from_numpy (transformation ).float ())
304
299
batch_id += 1
305
300
306
- pointclouds_batch = torch .cat (pointclouds_batch , 0 ).float ()
307
- transformations_batch = torch .cat (transformations_batch , 0 ).float ()
308
301
return coords_batch , feats_batch , labels_batch , transformations_batch
0 commit comments