@@ -276,8 +276,49 @@ def dense_coordinates(shape: Union[list, torch.Size]):
276
276
return coordinates
277
277
278
278
279
- def to_sparse (dense_tensor : torch .Tensor , coordinates : torch .Tensor = None ):
280
- r"""Converts a (differentiable) dense tensor to a sparse tensor.
279
+ def to_sparse (x : torch .Tensor , format : str = None , coordinates = None , device = None ):
280
+ r"""Convert a batched tensor (dimension 0 is the batch dimension) to a SparseTensor
281
+
282
+ :attr:`x` (:attr:`torch.Tensor`): a batched tensor. The first dimension is the batch dimension.
283
+
284
+ :attr:`format` (:attr:`str`): Format of the tensor. It must include 'B' and 'C' indicating the batch and channel dimension respectively. The rest of the dimensions must be 'X'. .e.g. format="BCXX" if image data with BCHW format is used. If a 3D data with the channel at the last dimension, use format="BXXXC" indicating Batch X Height X Width X Depth X Channel. If not provided, the format will be "BCX...X".
285
+
286
+ :attr:`device`: Device the sparse tensor will be generated on. If not provided, the device of the input tensor will be used.
287
+
288
+ """
289
+ assert x .ndim > 2 , "Input has 0 spatial dimension."
290
+ assert isinstance (x , torch .Tensor )
291
+ if format is None :
292
+ format = [
293
+ "X" ,
294
+ ] * x .ndim
295
+ format [0 ] = "B"
296
+ format [1 ] = "C"
297
+ format = "" .join (format )
298
+ assert x .ndim == len (format ), f"Invalid format: { format } . len(format) != x.ndim"
299
+ assert (
300
+ "B" in format and "B" == format [0 ] and format .count ("B" ) == 1
301
+ ), "The input must have the batch axis and the format must include 'B' indicating the batch axis."
302
+ assert (
303
+ "C" in format and format .count ("C" ) == 1
304
+ ), "The format must indicate the channel axis"
305
+ if device is None :
306
+ device = x .device
307
+ ch_dim = format .find ("C" )
308
+ reduced_x = torch .abs (x ).sum (ch_dim )
309
+ bcoords = torch .where (reduced_x != 0 )
310
+ stacked_bcoords = torch .stack (bcoords , dim = 1 ).int ()
311
+ indexing = [f"bcoords[{ i } ]" for i in range (len (bcoords ))]
312
+ indexing .insert (ch_dim , ":" )
313
+ features = torch .zeros (
314
+ (len (stacked_bcoords ), x .size (ch_dim )), dtype = x .dtype , device = x .device
315
+ )
316
+ exec ("features[:] = x[" + ", " .join (indexing ) + "]" )
317
+ return SparseTensor (features = features , coordinates = stacked_bcoords , device = device )
318
+
319
+
320
+ def to_sparse_all (dense_tensor : torch .Tensor , coordinates : torch .Tensor = None ):
321
+ r"""Converts a (differentiable) dense tensor to a sparse tensor with all coordinates.
281
322
282
323
Assume the input to have BxCxD1xD2x....xDN format.
283
324
@@ -312,6 +353,9 @@ class MinkowskiToSparseTensor(MinkowskiModuleBase):
312
353
313
354
For dense tensor, the input must have the BxCxD1xD2x....xDN format.
314
355
356
+ :attr:`remove_zeros` (bool): if True, removes zero valued coordinates. If
357
+ False, use all coordinates to populate a sparse tensor. True by default.
358
+
315
359
If the shape of the tensor do not change, use `dense_coordinates` to cache the coordinates.
316
360
Please refer to tests/python/dense.py for usage.
317
361
@@ -327,7 +371,7 @@ class MinkowskiToSparseTensor(MinkowskiModuleBase):
327
371
>>> network = nn.Sequential(
328
372
>>> # Add layers that can be applied on a regular pytorch tensor
329
373
>>> nn.ReLU(),
330
- >>> MinkowskiToSparseTensor(coordinates=coordinates),
374
+ >>> MinkowskiToSparseTensor(remove_zeros=False, coordinates=coordinates),
331
375
>>> MinkowskiConvolution(4, 5, kernel_size=3, dimension=4),
332
376
>>> MinkowskiBatchNorm(5),
333
377
>>> MinkowskiReLU(),
@@ -341,16 +385,23 @@ class MinkowskiToSparseTensor(MinkowskiModuleBase):
341
385
342
386
"""
343
387
344
- def __init__ (self , coordinates : torch .Tensor = None ):
388
+ def __init__ (self , remove_zeros = True , coordinates : torch .Tensor = None ):
345
389
MinkowskiModuleBase .__init__ (self )
390
+ assert (
391
+ remove_zeros and coordinates is None
392
+ ), "The coordinates argument cannot be used with remove_zeros=True. If you want to use the coordinates argument, provide remove_zeros=False."
393
+ self .remove_zeros = remove_zeros
346
394
self .coordinates = coordinates
347
395
348
396
def forward (self , input : Union [TensorField , torch .Tensor ]):
349
397
if isinstance (input , TensorField ):
350
398
return input .sparse ()
351
399
elif isinstance (input , torch .Tensor ):
352
400
# dense tensor to sparse tensor conversion
353
- return to_sparse (input , self .coordinates )
401
+ if self .remove_zeros :
402
+ return to_sparse (input )
403
+ else :
404
+ return to_sparse_all (input , self .coordinates )
354
405
else :
355
406
raise ValueError (
356
407
"Unsupported type. Only TensorField and torch.Tensor are supported"
0 commit comments