Dense to Sparse conversion keeps zero values #317

Closed
AlexeyGB opened this issue Feb 16, 2021 · 1 comment

@AlexeyGB
Contributor

Describe the bug
I have a dense tensor with 3D data (voxels from ScanNet) and I want to convert it to a sparse tensor for sparse NN training. I did not find any way to do this except the ME.to_sparse() function (the MinkowskiToSparseTensor module uses the same function under the hood). But when I converted my dense tensor, I found that the new sparse tensor stores all values of the dense tensor as features, including the zero values.
I believe that this behaviour of the ME.to_sparse() function is incorrect, since a sparse tensor should not store zero values by definition.
If I am wrong, please correct me and tell me how I should convert a dense tensor to a sparse one so that further operations stay efficient.

To Reproduce
I use a simpler tensor to reproduce the behaviour.

import torch
import MinkowskiEngine as ME

dense_batch = torch.tensor(
    [[[[0., 0.],
      [5., 0.]]],

    [[[0., 5.],
      [0., 0.]]],

    [[[0., 0.],
      [5., 0.]]]])
# shape is [3, 1, 2, 2]

sparse = ME.to_sparse(dense_batch)
sparse

Returns:

SparseTensor(
  coordinates=tensor([[0, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [0, 1, 1],
        [1, 0, 0],
        [1, 0, 1],
        [1, 1, 0],
        [1, 1, 1],
        [2, 0, 0],
        [2, 0, 1],
        [2, 1, 0],
        [2, 1, 1]], dtype=torch.int32)
  features=tensor([[0.],
        [0.],
        [5.],
        [0.],
        [0.],
        [5.],
        [0.],
        [0.],
        [0.],
        [0.],
        [5.],
        [0.]])
  coordinate_map_key=coordinate map key:[1, 1]
  coordinate_manager=CoordinateMapManagerCPU(
	[1, 1, ]:	CoordinateMapCPU:12x3
	algorithm=MinkowskiAlgorithm.DEFAULT
  )
  spatial dimension=2)

Expected behavior
I expect that the resulting sparse tensor will store only non-zero values.
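
For the small batch above, the expected result can be made concrete with plain PyTorch. This is only a sketch of the filtering step, not MinkowskiEngine's implementation. The helper name `dense_to_nonzero_coo` is hypothetical, and the code assumes a [B, C, D1, ..., Dn] layout in which a site counts as occupied when at least one channel is non-zero.

```python
import torch


def dense_to_nonzero_coo(dense: torch.Tensor):
    """Hypothetical helper: split a [B, C, D1, ..., Dn] dense tensor into
    (coordinates, features), keeping only sites with a non-zero channel."""
    # Move channels last so every spatial site owns one feature row:
    # [B, C, D1, ..., Dn] -> [B, D1, ..., Dn, C].
    channels_last = dense.permute(0, *range(2, dense.dim()), 1)
    # Assumption: a site is occupied if any of its channels is non-zero.
    occupied = (channels_last != 0).any(dim=-1)
    # One row per occupied site: (batch_index, x1, ..., xn), stored as int32.
    coordinates = occupied.nonzero().int()
    # Boolean indexing walks the sites in the same row-major order as nonzero().
    features = channels_last[occupied]
    return coordinates, features


dense_batch = torch.tensor(
    [[[[0., 0.],
       [5., 0.]]],

     [[[0., 5.],
       [0., 0.]]],

     [[[0., 0.],
       [5., 0.]]]])

coordinates, features = dense_to_nonzero_coo(dense_batch)
print(coordinates)  # 3 rows instead of 12
print(features)
```

With this input only the coordinates [0, 1, 0], [1, 0, 1] and [2, 1, 0] remain, each with the feature 5. The two tensors have the shapes that a SparseTensor's coordinates and features take in the output above, so they could presumably be used to build one by hand, although that step is not shown or tested here.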

Desktop (please complete the following information):

OS: Ubuntu 18.04.4 LTS
Python version: 3.8.1
CUDA version: 10.2
NVIDIA Driver version: 440.64
Minkowski Engine version: 0.5.1

  • Output of the following command. (If you installed the latest MinkowskiEngine, simply call MinkowskiEngine.print_diagnostics())
wget -q https://raw.githubusercontent.com/NVIDIA/MinkowskiEngine/master/MinkowskiEngine/diagnostics.py ; python diagnostics.py
==========System==========
Linux-5.3.0-28-generic-x86_64-with-glibc2.10
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.4 LTS"
3.8.1 (default, Jan  8 2020, 22:29:32) 
[GCC 7.3.0]
==========Pytorch==========
1.7.1
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 440.64
CUDA Version 10.2
VBIOS Version 90.02.2E.00.0C
Image Version G001.0000.02.04
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
==========CC==========
/usr/bin/c++
c++ (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.1
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 10020
CUDART version MinkowskiEngine is compiled: 10020

Additional context

@AlexeyGB AlexeyGB changed the title Dense to Sparse conversion Dense to Sparse conversion keeps zero values Feb 17, 2021
@chrischoy
Contributor

Renamed the previous to_sparse function to to_sparse_all and added a new to_sparse that removes the zero-valued coordinates:

def to_sparse(x: torch.Tensor, format: str = None, coordinates=None, device=None):

Please read the docstring to provide the correct format.
