diff --git a/matrix_functions.py b/matrix_functions.py index 40e945c..948aa9d 100644 --- a/matrix_functions.py +++ b/matrix_functions.py @@ -104,7 +104,6 @@ def matrix_inverse_root( A=A, root=root, epsilon=epsilon, - make_positive_semidefinite=root_inv_config.make_positive_semidefinite, retry_double_precision=root_inv_config.retry_double_precision, eigen_decomp_offload_device=root_inv_config.eigen_decomp_offload_device, ) @@ -210,7 +209,6 @@ def _matrix_inverse_root_eigen( A: Tensor, root: Fraction, epsilon: float = 0.0, - make_positive_semidefinite: bool = True, retry_double_precision: bool = True, eigen_decomp_offload_device: torch.device | str = "", ) -> tuple[Tensor, Tensor, Tensor]: @@ -224,7 +222,6 @@ def _matrix_inverse_root_eigen( A (Tensor): Square matrix of interest. root (Fraction): Root of interest. Any rational number. epsilon (float): Adds epsilon * I to matrix before taking matrix root. (Default: 0.0) - make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True) retry_double_precision (bool): Flag for re-trying eigendecomposition with higher precision if lower precision fails due to CuSOLVER failure. (Default: True) eigen_decomp_offload_device (torch.device | str): Device to offload eigen decomposition computation. If value is empty string, do not perform offloading. (Default: "") @@ -248,14 +245,8 @@ def _matrix_inverse_root_eigen( eigen_decomp_offload_device=eigen_decomp_offload_device, ) - lambda_min = torch.min(L) - - # make eigenvalues >= 0 (if necessary) - if make_positive_semidefinite: - L += -torch.minimum(lambda_min, torch.as_tensor(0.0)) - - # add epsilon - L += epsilon + # make eigenvalues > 0 (if necessary) + L += -torch.minimum(torch.min(L) - epsilon, torch.as_tensor(0.0)) # compute inverse preconditioner X = Q * L.pow(torch.as_tensor(-1.0 / root)).unsqueeze(0) @ Q.T @@ -596,7 +587,6 @@ def compute_matrix_root_inverse_residuals( X_hat.double(), root=root, epsilon=0.0, - make_positive_semidefinite=True, eigen_decomp_offload_device=root_inv_config.eigen_decomp_offload_device, ) diff --git a/matrix_functions_types.py b/matrix_functions_types.py index dc71b7f..f42414f 100644 --- a/matrix_functions_types.py +++ b/matrix_functions_types.py @@ -54,13 +54,11 @@ class EigenConfig(RootInvConfig, EigenvalueDecompositionConfig): retry_double_precision (bool): Whether to re-trying eigendecomposition with higher (double) precision if lower precision fails due to CuSOLVER failure. (Default: True) eigen_decomp_offload_device (torch.device | str): Device to offload eigen decomposition to. If value is empty string, we don't perform offloading. (Default: "") - make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True) exponent_multiplier (float): Number to be multiplied to the numerator of the inverse root, i.e., eta where the exponent is -eta / (2 * p). (Default: 1.0) """ - make_positive_semidefinite: bool = True exponent_multiplier: float = 1.0 diff --git a/tests/matrix_functions_test.py b/tests/matrix_functions_test.py index 5c1cb98..3c0df94 100644 --- a/tests/matrix_functions_test.py +++ b/tests/matrix_functions_test.py @@ -333,7 +333,6 @@ def _test_eigen_root( self, A: torch.Tensor, root: int, - make_positive_semidefinite: bool, epsilon: float, tolerance: float, eig_sols: Tensor, @@ -342,7 +341,6 @@ def _test_eigen_root( A=A, root=Fraction(root), epsilon=epsilon, - make_positive_semidefinite=make_positive_semidefinite, ) abs_error = torch.dist(torch.linalg.matrix_power(X, -root), A, p=torch.inf) A_norm = torch.linalg.norm(A, ord=torch.inf) @@ -355,7 +353,6 @@ def _test_eigen_root_multi_dim( A: Callable[[int], Tensor], dims: list[int], roots: list[int], - make_positive_semidefinite: bool, epsilons: list[float], tolerance: float, eig_sols: Callable[[int], Tensor], @@ -365,7 +362,6 @@ def _test_eigen_root_multi_dim( self._test_eigen_root( A(n), root, - make_positive_semidefinite, epsilon, tolerance, eig_sols(n), @@ -376,7 +372,6 @@ def test_eigen_root_identity(self) -> None: dims = [10, 100] roots = [1, 2, 4, 8] epsilons = [0.0] - make_positive_semidefinite = False def eig_sols(n: int) -> Tensor: return torch.ones(n) @@ -384,16 +379,13 @@ def eig_sols(n: int) -> Tensor: def A(n: int) -> Tensor: return torch.eye(n) - self._test_eigen_root_multi_dim( - A, dims, roots, make_positive_semidefinite, epsilons, tolerance, eig_sols - ) + self._test_eigen_root_multi_dim(A, dims, roots, epsilons, tolerance, eig_sols) def test_eigen_root_tridiagonal_1(self) -> None: tolerance = 1e-4 dims = [10, 100] roots = [1, 2, 4, 8] epsilons = [0.0] - make_positive_semidefinite = False for alpha, beta in itertools.product( [0.001, 0.01, 0.1, 1.0, 10.0, 100.0], repeat=2 @@ -425,7 +417,6 @@ def A(n: int, alpha: float, beta: float) -> Tensor: partial(A, alpha=alpha, beta=beta), dims, roots, - make_positive_semidefinite, epsilons, tolerance, partial(eig_sols, alpha=alpha, beta=beta), @@ -436,7 +427,6 @@ def test_eigen_root_tridiagonal_2(self) -> None: dims = [10, 100] roots = [1, 2, 4, 8] epsilons = [0.0] - make_positive_semidefinite = False for alpha, beta in itertools.product( [0.001, 0.01, 0.1, 1.0, 10.0, 100.0], repeat=2 @@ -471,7 +461,6 @@ def A(n: int, alpha: float, beta: float) -> Tensor: partial(A, alpha=alpha, beta=beta), dims, roots, - make_positive_semidefinite, epsilons, tolerance, partial(eig_sols, alpha=alpha, beta=beta),