Commit 26ddffb ("init")

Initial commit (0 parents). 21 files changed: +2940, -0 lines.

.gitignore

+20
# Byte-compiled / optimized
__pycache__/
*.py[cod]
*.o
*.so

# Text edits
*.swp
*.swo
*.orig

# Training
outputs/

# Profiling
callgrind.out*
*.dSYM

# Misc
.DS_Store

.style.yapf

+4
[style]
based_on_style = chromium
column_limit = 100
indent_width = 2

README.md

+3
# Spatio-Temporal Segmentation

This repository contains the accompanying code for [4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR'19](https://arxiv.org/abs/1904.08755).

lib/__init__.py

Whitespace-only changes.

lib/dataloader.py

+34
import torch
from torch.utils.data.sampler import Sampler


class InfSampler(Sampler):
  """Samples elements without replacement within each pass and restarts
  indefinitely, re-permuting the indices when `shuffle=True`.

  Arguments:
    data_source (Dataset): dataset to sample from
  """

  def __init__(self, data_source, shuffle=False):
    self.data_source = data_source
    self.shuffle = shuffle
    self.reset_permutation()

  def reset_permutation(self):
    n = len(self.data_source)
    # Random permutation when shuffling, a fixed ordering otherwise.
    perm = torch.randperm(n) if self.shuffle else torch.arange(n)
    self._perm = perm.tolist()

  def __iter__(self):
    return self

  def __next__(self):
    if len(self._perm) == 0:
      self.reset_permutation()
    return self._perm.pop()

  def __len__(self):
    return len(self.data_source)

  next = __next__  # Python 2 compatibility
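
A minimal usage sketch (mine, not part of the commit): plugging InfSampler into a standard DataLoader so that iteration never terminates, which is how an infinite sampler is typically consumed; the toy TensorDataset and batch size below are assumptions for illustration.

import torch
from torch.utils.data import DataLoader, TensorDataset

from lib.dataloader import InfSampler

# Hypothetical toy dataset; any map-style Dataset works.
dataset = TensorDataset(torch.arange(10).float())

# With InfSampler the loader cycles forever; reset_permutation() re-shuffles
# (when shuffle=True) each time a full pass over the dataset has been consumed.
loader = DataLoader(dataset, batch_size=4, sampler=InfSampler(dataset, shuffle=True))

data_iter = iter(loader)
for step in range(5):        # more batches than a single pass (3) would provide
  batch, = next(data_iter)   # never raises StopIteration
  print(step, batch.shape)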

lib/layers.py

+85
import torch
import torch.nn as nn

from MinkowskiEngine import (MinkowskiGlobalPooling, MinkowskiBroadcastAddition,
                             MinkowskiBroadcastMultiplication)


class MinkowskiLayerNorm(nn.Module):
  """Layer normalization for sparse tensors: statistics are computed over all
  points and channels of each batch item via global pooling, then broadcast back."""

  def __init__(self, num_features, eps=1e-5, D=-1):
    super(MinkowskiLayerNorm, self).__init__()
    self.num_features = num_features
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(1, num_features))
    self.bias = nn.Parameter(torch.zeros(1, num_features))

    self.mean_in = MinkowskiGlobalPooling(dimension=D)
    self.glob_sum = MinkowskiBroadcastAddition(dimension=D)
    self.glob_sum2 = MinkowskiBroadcastAddition(dimension=D)
    self.glob_mean = MinkowskiGlobalPooling(dimension=D)
    self.glob_times = MinkowskiBroadcastMultiplication(dimension=D)
    self.D = D
    self.reset_parameters()

  def __repr__(self):
    s = f'(D={self.D})'
    return self.__class__.__name__ + s

  def reset_parameters(self):
    self.weight.data.fill_(1)
    self.bias.data.zero_()

  def _check_input_dim(self, input):
    if input.F.dim() != 2:
      raise ValueError('expected 2D feature matrix (got {}D input)'.format(input.F.dim()))

  def forward(self, x):
    self._check_input_dim(x)
    # Per-instance mean over points (global pooling) and channels.
    mean = self.mean_in(x).F.mean(-1, keepdim=True)
    mean = mean + torch.zeros(mean.size(0), self.num_features).type_as(mean)
    # Per-instance variance, then the inverse standard deviation.
    temp = self.glob_sum(x.F, -mean)**2
    var = self.glob_mean(temp.data).mean(-1, keepdim=True)
    var = var + torch.zeros(var.size(0), self.num_features).type_as(var)
    instd = 1 / (var + self.eps).sqrt()

    x = self.glob_times(self.glob_sum2(x, -mean), instd)
    return x * self.weight + self.bias


class MinkowskiInstanceNorm(nn.Module):
  """Instance normalization for sparse tensors: per-channel statistics are
  computed over the points of each batch item via global pooling."""

  def __init__(self, num_features, eps=1e-5, D=-1):
    super(MinkowskiInstanceNorm, self).__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(1, num_features))
    self.bias = nn.Parameter(torch.zeros(1, num_features))

    self.mean_in = MinkowskiGlobalPooling(dimension=D)
    self.glob_sum = MinkowskiBroadcastAddition(dimension=D)
    self.glob_sum2 = MinkowskiBroadcastAddition(dimension=D)
    self.glob_mean = MinkowskiGlobalPooling(dimension=D)
    self.glob_times = MinkowskiBroadcastMultiplication(dimension=D)
    self.D = D
    self.reset_parameters()

  def __repr__(self):
    s = f'(D={self.D})'
    return self.__class__.__name__ + s

  def reset_parameters(self):
    self.weight.data.fill_(1)
    self.bias.data.zero_()

  def _check_input_dim(self, input):
    if input.F.dim() != 2:
      raise ValueError('expected 2D feature matrix (got {}D input)'.format(input.F.dim()))

  def forward(self, x):
    self._check_input_dim(x)
    mean_in = self.mean_in(x)                 # per-instance, per-channel mean
    temp = self.glob_sum(x, -mean_in)**2
    var_in = self.glob_mean(temp.data)        # per-instance, per-channel variance
    instd_in = 1 / (var_in + self.eps).sqrt()

    x = self.glob_times(self.glob_sum2(x, -mean_in), instd_in)
    return x * self.weight + self.bias
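
For intuition, a dense reference sketch (mine, not part of the commit) of what the two modules compute on the (num_points, num_features) feature matrix of a single point cloud; the sparse versions do the same per batch item through MinkowskiGlobalPooling and the broadcast ops. All names below are illustrative.

import torch

def dense_layer_norm(F, weight, bias, eps=1e-5):
  # MinkowskiLayerNorm analogue: one mean/variance over all points and channels.
  mean = F.mean()
  var = ((F - mean)**2).mean()
  return (F - mean) / (var + eps).sqrt() * weight + bias

def dense_instance_norm(F, weight, bias, eps=1e-5):
  # MinkowskiInstanceNorm analogue: per-channel mean/variance over the points.
  mean = F.mean(0, keepdim=True)
  var = ((F - mean)**2).mean(0, keepdim=True)
  return (F - mean) / (var + eps).sqrt() * weight + bias

F = torch.randn(100, 16)                      # 100 points, 16 feature channels
w, b = torch.ones(1, 16), torch.zeros(1, 16)  # same shapes as the module parameters
print(dense_layer_norm(F, w, b).shape, dense_instance_norm(F, w, b).shape)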

lib/math_functions.py

+70
from scipy.sparse import csr_matrix
import torch


class SparseMM(torch.autograd.Function):
  """
  Sparse x dense matrix multiplication with autograd support.
  Implementation by Soumith Chintala:
  https://discuss.pytorch.org/t/
  does-pytorch-support-autograd-on-sparse-matrix/6156/7
  Note: uses the legacy autograd.Function interface (instance forward/backward).
  """

  def forward(self, matrix1, matrix2):
    self.save_for_backward(matrix1, matrix2)
    return torch.mm(matrix1, matrix2)

  def backward(self, grad_output):
    matrix1, matrix2 = self.saved_tensors
    grad_matrix1 = grad_matrix2 = None

    if self.needs_input_grad[0]:
      grad_matrix1 = torch.mm(grad_output, matrix2.t())

    if self.needs_input_grad[1]:
      grad_matrix2 = torch.mm(matrix1.t(), grad_output)

    return grad_matrix1, grad_matrix2


def sparse_float_tensor(values, indices, size=None):
  """
  Return a torch sparse matrix given values and indices (row_ind, col_ind).
  If the size is an integer, return a square matrix with side size.
  If the size is a torch.Size, use it to initialize the out tensor.
  If none, the size is inferred.
  """
  indices = torch.stack(indices).long()  # sparse constructors require LongTensor indices
  sargs = [indices, values.float()]
  if size is not None:
    # Use the provided size
    if isinstance(size, int):
      size = torch.Size((size, size))
    sargs.append(size)
  if values.is_cuda:
    return torch.cuda.sparse.FloatTensor(*sargs)
  else:
    return torch.sparse.FloatTensor(*sargs)


def diags(values, size=None):
  """Build a square sparse diagonal matrix from a flat vector of values."""
  values = values.view(-1)
  n = values.nelement()
  size = torch.Size((n, n))
  indices = (torch.arange(0, n), torch.arange(0, n))
  return sparse_float_tensor(values, indices, size)


def sparse_to_csr_matrix(tensor):
  """Convert a torch COO sparse tensor to a scipy CSR matrix."""
  tensor = tensor.cpu()
  inds = tensor._indices().numpy()
  vals = tensor._values().numpy()
  return csr_matrix((vals, (inds[0], inds[1])), shape=[s for s in tensor.shape])


def csr_matrix_to_sparse(mat):
  """Convert a scipy sparse matrix to a torch sparse float tensor."""
  row_ind, col_ind = mat.nonzero()
  return sparse_float_tensor(
      torch.from_numpy(mat.data),
      (torch.from_numpy(row_ind), torch.from_numpy(col_ind)),
      size=torch.Size(mat.shape))
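
A small round-trip sketch (mine, not part of the commit) showing how the helpers fit together: diags builds a sparse diagonal matrix, sparse_to_csr_matrix hands it to scipy, and csr_matrix_to_sparse brings it back. The expected values in the comments assume the legacy torch.sparse.FloatTensor constructor used above behaves as in the PyTorch releases of that era.

import torch

from lib.math_functions import diags, sparse_to_csr_matrix, csr_matrix_to_sparse

d = diags(torch.tensor([1., 2., 3.]))     # 3x3 sparse COO matrix with diag(1, 2, 3)
print(d.to_dense())

csr = sparse_to_csr_matrix(d)             # scipy.sparse.csr_matrix, shape (3, 3)
back = csr_matrix_to_sparse(csr)          # back to a torch sparse float tensor
print((back.to_dense() - d.to_dense()).abs().max())  # expected: tensor(0.)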

lib/solvers.py

+78
import logging

from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR, StepLR


class LambdaStepLR(LambdaLR):
  """LambdaLR that exposes the epoch counter as a step counter."""

  def __init__(self, optimizer, lr_lambda, last_step=-1):
    super(LambdaStepLR, self).__init__(optimizer, lr_lambda, last_step)

  @property
  def last_step(self):
    """Use last_epoch for the step counter."""
    return self.last_epoch

  @last_step.setter
  def last_step(self, v):
    self.last_epoch = v


class PolyLR(LambdaStepLR):
  """DeepLab learning rate policy: lr * (1 - step / (max_iter + 1)) ** power."""

  def __init__(self, optimizer, max_iter, power=0.9, last_step=-1):
    super(PolyLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**power, last_step)


class SquaredLR(LambdaStepLR):
  """Used for SGD LARS: lr * (1 - step / (max_iter + 1)) ** 2."""

  def __init__(self, optimizer, max_iter, last_step=-1):
    super(SquaredLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**2, last_step)


class ExpLR(LambdaStepLR):
  """Exponential decay: lr * gamma ** (step / step_size)."""

  def __init__(self, optimizer, step_size, gamma=0.9, last_step=-1):
    # With gamma = 0.9 the multiplier reaches 0.1 after ~21.854 * step_size steps
    # (0.9 ** 21.854 ~ 0.1); with gamma = 0.95, after ~44.89 * step_size steps.
    # To decay to 0.1 of the base lr every N * step_size steps, solve gamma ** N = 0.1,
    # i.e. gamma = exp(log(0.1) / N).
    super(ExpLR, self).__init__(optimizer, lambda s: gamma**(s / step_size), last_step)


def initialize_optimizer(params, config):
  assert config.optimizer in ['SGD', 'Adagrad', 'Adam', 'RMSProp', 'Rprop', 'SGDLars']

  if config.optimizer == 'SGD':
    return SGD(
        params,
        lr=config.lr,
        momentum=config.sgd_momentum,
        dampening=config.sgd_dampening,
        weight_decay=config.weight_decay)
  elif config.optimizer == 'Adam':
    return Adam(
        params,
        lr=config.lr,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.weight_decay)
  else:
    logging.error('Optimizer type not supported')
    raise ValueError('Optimizer type not supported')


def initialize_scheduler(optimizer, config, last_step=-1):
  if config.scheduler == 'StepLR':
    return StepLR(
        optimizer, step_size=config.step_size, gamma=config.step_gamma, last_epoch=last_step)
  elif config.scheduler == 'PolyLR':
    return PolyLR(optimizer, max_iter=config.max_iter, power=config.poly_power, last_step=last_step)
  elif config.scheduler == 'SquaredLR':
    return SquaredLR(optimizer, max_iter=config.max_iter, last_step=last_step)
  elif config.scheduler == 'ExpLR':
    return ExpLR(
        optimizer, step_size=config.exp_step_size, gamma=config.exp_gamma, last_step=last_step)
  else:
    logging.error('Scheduler not supported')
    raise ValueError('Scheduler not supported')
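
A usage sketch (mine, not part of the commit): initialize_optimizer and initialize_scheduler only need attribute access on the config, so a SimpleNamespace stands in for the project's real argparse config here, and the field values are made up for illustration.

from types import SimpleNamespace

import torch.nn as nn

from lib.solvers import initialize_optimizer, initialize_scheduler

# Hypothetical config; the attribute names mirror those read in solvers.py.
config = SimpleNamespace(
    optimizer='SGD', lr=0.1, sgd_momentum=0.9, sgd_dampening=0.1, weight_decay=1e-4,
    scheduler='PolyLR', max_iter=60000, poly_power=0.9)

model = nn.Linear(8, 4)  # stand-in for the real network
optimizer = initialize_optimizer(model.parameters(), config)
scheduler = initialize_scheduler(optimizer, config)

for step in range(3):    # inside the training loop, normally after loss.backward()
  optimizer.step()
  scheduler.step()       # PolyLR multiplies lr by (1 - step / (max_iter + 1)) ** 0.9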
