-
Notifications
You must be signed in to change notification settings - Fork 51
/
Copy pathmath_functions.py
70 lines (55 loc) · 2.01 KB
/
math_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from scipy.sparse import csr_matrix
import torch
class SparseMM(torch.autograd.Function):
    """
    Sparse x dense matrix multiplication with autograd support.

    Computes ``torch.mm(matrix1, matrix2)`` in the forward pass and supplies
    the gradient with respect to each operand explicitly in the backward
    pass.

    Call it as ``SparseMM.apply(matrix1, matrix2)``. The legacy spelling
    ``SparseMM()(matrix1, matrix2)`` is still accepted and simply forwards
    to ``apply``.

    Implementation by Soumith Chintala:
    https://discuss.pytorch.org/t/
    does-pytorch-support-autograd-on-sparse-matrix/6156/7
    """

    @staticmethod
    def forward(ctx, matrix1, matrix2):
        """Return ``matrix1 @ matrix2`` and stash both operands for backward."""
        ctx.save_for_backward(matrix1, matrix2)
        return torch.mm(matrix1, matrix2)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return ``(grad_matrix1, grad_matrix2)`` for the incoming gradient.

        A gradient is only computed for an operand that requires it; the
        other slot is returned as ``None``.
        """
        matrix1, matrix2 = ctx.saved_tensors
        grad_matrix1 = grad_matrix2 = None

        if ctx.needs_input_grad[0]:
            # d(A @ B)/dA applied to grad_output: grad_output @ B^T
            grad_matrix1 = torch.mm(grad_output, matrix2.t())
        if ctx.needs_input_grad[1]:
            # d(A @ B)/dB applied to grad_output: A^T @ grad_output
            grad_matrix2 = torch.mm(matrix1.t(), grad_output)

        return grad_matrix1, grad_matrix2

    def __call__(self, matrix1, matrix2):
        """Backward-compatibility shim for callers using ``SparseMM()(a, b)``."""
        # NOTE: PyTorch deprecates instantiating autograd Functions; prefer
        # SparseMM.apply(...) in new code.
        return type(self).apply(matrix1, matrix2)
def sparse_float_tensor(values, indices, size=None):
    """
    Return a torch sparse (COO) float matrix given values and indices.

    Args:
        values: tensor holding the stored entries; it is cast to float32.
        indices: pair ``(row_ind, col_ind)`` of index tensors, one entry per
            stored value. They are stacked into a single index tensor.
        size: shape of the result.
            If it is an integer, a square matrix with that side is returned.
            If it is a ``torch.Size``, it is used as the shape directly.
            If ``None``, the shape is inferred from the indices.

    Returns:
        A sparse COO tensor located on the same device as ``values``.
    """
    # Sparse COO indices must be int64, and must live on the same device as
    # the values they address.
    indices = torch.stack(indices).long().to(values.device)
    values = values.float()

    if size is None:
        return torch.sparse_coo_tensor(indices, values, device=values.device)

    if isinstance(size, int):
        size = torch.Size((size, size))
    return torch.sparse_coo_tensor(indices, values, size, device=values.device)
def diags(values, size=None):
    """
    Build a sparse matrix whose main diagonal holds ``values``.

    Args:
        values: tensor of diagonal entries; it is flattened, so entry ``i``
            of the flattened tensor lands at position ``(i, i)``.
        size: optional shape of the result, as an integer side length or a
            ``torch.Size``. When ``None`` the matrix is ``n x n`` with ``n``
            the number of elements in ``values``.

    Returns:
        The sparse float matrix produced by ``sparse_float_tensor``.
    """
    values = values.view(-1)
    n = values.nelement()
    # Only fall back to n x n when the caller did not request a shape.
    if size is None:
        size = torch.Size((n, n))
    # Build the indices on the values' device so both operands agree.
    diagonal = torch.arange(0, n, device=values.device)
    return sparse_float_tensor(values, (diagonal, diagonal), size)
def sparse_to_csr_matrix(tensor):
    """
    Convert a torch sparse tensor into a ``scipy.sparse.csr_matrix``.

    The tensor is copied to the CPU, its stored indices and values are read
    out as NumPy arrays, and a CSR matrix with the tensor's shape is built
    from them.
    """
    host = tensor.cpu()
    coords = host._indices().numpy()
    data = host._values().numpy()
    rows = coords[0]
    cols = coords[1]
    return csr_matrix((data, (rows, cols)), shape=tuple(host.shape))
def csr_matrix_to_sparse(mat):
    """
    Convert a scipy sparse matrix into a torch sparse float matrix.

    Args:
        mat: scipy sparse matrix (anything providing ``tocoo()`` and
            ``shape``).

    Returns:
        The sparse float matrix produced by ``sparse_float_tensor``, with the
        same shape as ``mat``.
    """
    # Take coordinates and values from one COO view so they are aligned entry
    # by entry. Pairing mat.nonzero() with mat.data is not safe: nonzero()
    # drops explicitly stored zeros while mat.data keeps them.
    coo = mat.tocoo()
    return sparse_float_tensor(
        torch.from_numpy(coo.data),
        (torch.from_numpy(coo.row), torch.from_numpy(coo.col)),
        size=torch.Size(mat.shape))