Save some memory in simplex constrain #3168
base: develop
Nice find! This reminds me that we should put in the ILR-based simplex instead of stick breaking. My only holdup is that I'd like to keep the stick-breaking code as well. We haven't come to a consensus on how we add a different parameterization of the same constraint type.
@bob-carpenter has stated before that he's of the opinion that we just completely replace stick-breaking with the ILR transform for simplexes, rather than provide an option. Another ~easy option is to provide a compile-time define to switch, rather than deciding on a language-syntax-level option.
The ILR one is super easy. We construct a zero_sum_vector (just like we have already) and then softmax it. There's a Jacobian adjustment for the softmax.
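To make the "zero-sum vector, then softmax" recipe concrete, here is a minimal standalone sketch in plain C++ (no Eigen, no autodiff). The names `zero_sum_from_unconstrained`, `SimplexResult`, and `simplex_ilr_sketch` are hypothetical and not part of Stan Math. It assumes the zero-sum vector is built from an orthonormal basis, in which case the log-Jacobian works out to 0.5 * log(N + 1) - (N + 1) * log_sum_exp(z) for an unconstrained input of size N.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Map an unconstrained vector y (size N) to a zero-sum vector z (size N + 1)
// using an orthonormal (isometric) basis of the zero-sum subspace.
std::vector<double> zero_sum_from_unconstrained(const std::vector<double>& y) {
  const std::size_t N = y.size();
  std::vector<double> z(N + 1, 0.0);
  double sum_w = 0.0;
  for (std::size_t i = N; i > 0; --i) {
    const double n = static_cast<double>(i);
    const double w = y[i - 1] / std::sqrt(n * (n + 1.0));
    sum_w += w;
    z[i - 1] += sum_w;
    z[i] -= w * n;
  }
  return z;
}

// Hypothetical result type: the simplex and the log absolute Jacobian.
struct SimplexResult {
  std::vector<double> x;
  double log_jacobian;
};

// Softmax of the zero-sum vector, computed with the max subtracted for safety.
SimplexResult simplex_ilr_sketch(const std::vector<double>& y) {
  const std::vector<double> z = zero_sum_from_unconstrained(y);
  assert(!z.empty());
  const double K = static_cast<double>(z.size());  // K = N + 1
  const double max_z = *std::max_element(z.begin(), z.end());
  double sum_exp = 0.0;
  for (double zi : z) {
    sum_exp += std::exp(zi - max_z);
  }
  const double log_sum_exp = max_z + std::log(sum_exp);
  SimplexResult out;
  out.x.reserve(z.size());
  for (double zi : z) {
    out.x.push_back(std::exp(zi - log_sum_exp));
  }
  // sum(log x) = sum(z) - K * log_sum_exp = -K * log_sum_exp, since sum(z) = 0.
  out.log_jacobian = 0.5 * std::log(K) - K * log_sum_exp;
  return out;
}
```

For N = 1 this reduces to a logistic transform of sqrt(2) * y, which gives a quick sanity check on the Jacobian term.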
Whatever is easiest. I'd like to save all the transforms in some directory. We can have the default ones be the ones we think are "best" from what we've seen and all the other ones stored separately. I think this requires a lot of repo organizing though.
We need to choose a default and I think this should be it. If there are cases where we'd still want to use the old one, we can leave in the constraining/unconstraining functions for it and let people do it manually.
Ah yeah, I saw that the constraint functions are getting exposed! Cool, that works then. @WardBrian we can do this in two loops using the online softmax (https://arxiv.org/abs/1805.02867). The first loop constructs the sum-to-zero vector, tracks its maximum value, and accumulates the sum of exponentials with the max subtracted. The second loop takes the safe exponential of the sum-to-zero vector with the max subtracted and divides by the sum of exponentials from the first loop. The Jacobian term is added afterwards.

```cpp
template <typename Vec, typename Lp>
inline plain_type_t<Vec> simplex_ilr_constrain(const Vec& y, Lp& lp) {
  const auto N = y.size();
  plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
  if (unlikely(N == 0)) {
    z.coeffRef(0) = 1;  // a size-1 simplex is the single point {1}
    return z;
  }
  auto&& y_ref = to_ref(y);
  value_type_t<Vec> sum_w(0);
  // new
  double d = 0;  // running sum of exponentials, relative to max_val
  double max_val = 0;
  double max_val_old = 0;
  for (int i = N; i > 0; --i) {
    double n = static_cast<double>(i);
    auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
    sum_w += w;
    z.coeffRef(i - 1) += sum_w;
    z.coeffRef(i) -= w * n;
    // new: z(i) is final at this point, so fold it into the running max/sum
    max_val = max(max_val_old, z.coeff(i));
    d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
    max_val_old = max_val;
  }
  // new: z(0) only becomes final once the loop ends, so fold it in here
  max_val = max(max_val_old, z.coeff(0));
  d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);
  // new loop: normalize all N + 1 entries
  for (int i = 0; i <= N; ++i) {
    z.coeffRef(i) = exp(z.coeff(i) - max_val) / d;
  }
  // log_sum_exp(z) = max_val + log(d), and the zero-sum vector has N + 1 entries
  lp += -(N + 1) * (max_val + log(d)) + 0.5 * log(N + 1);
  return z;
}
```
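The running max/sum update is easier to see in isolation. Below is a minimal standalone sketch of the online normalizer from the linked paper, in plain C++ with `double` only. `OnlineNormalizer` and `online_softmax` are hypothetical names for illustration, not Stan Math functions, and the sketch assumes finite inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Tracks the running maximum and the sum of exp(v - max) in a single pass.
struct OnlineNormalizer {
  double max_val = -std::numeric_limits<double>::infinity();
  double sum_exp = 0.0;  // sum of exp(v - max_val) over the values seen so far

  void update(double v) {
    const double new_max = std::max(max_val, v);
    // Rescale the old sum to the new maximum before adding the new term.
    sum_exp = sum_exp * std::exp(max_val - new_max) + std::exp(v - new_max);
    max_val = new_max;
  }
};

// Two-loop softmax: one pass for max and normalizer, one pass to normalize.
std::vector<double> online_softmax(const std::vector<double>& z) {
  assert(!z.empty());
  OnlineNormalizer norm;
  for (double v : z) {
    norm.update(v);
  }
  std::vector<double> out(z.size());
  for (std::size_t i = 0; i < z.size(); ++i) {
    out[i] = std::exp(z[i] - norm.max_val) / norm.sum_exp;
  }
  return out;
}
```

Every element has to pass through `update` before the second loop runs; otherwise the normalizer is too small and the maximum may be underestimated.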
Summary
I was looking at `simplex_constrain`'s rev implementation and realized two things:
- It stores the intermediate `z`s, but is otherwise the same as the prim implementation
- We recompute `z` in the reverse pass anyway (at least for the version that takes `lp`)

So, by some napkin math, we can save ~20% of the memory overhead of the rev implementation by just not storing these, and simplify the code by delegating the forward pass to prim.
The same is very nearly true for the stochastic matrices, except the forward pass inside the rev specialization is written in a more vectorized style than the prim one is, so it isn't a direct swap. It's still true that we are re-computing something we're also storing, so I still save the memory there, I just don't replace the forward with a call to prim.
If we want, I can move the vectorized stochastic code into prim and then do the other half of the change to rev for those functions as well.
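As an illustration of why the stored intermediates are redundant, here is a minimal sketch in plain C++ (no autodiff), assuming the usual stick-breaking form of the simplex transform. It is not the code in this PR, and `stick_breaking_constrain` and `recover_break_proportions` are hypothetical names. It shows that the break proportions can be rebuilt from the output simplex alone while walking backwards, as a reverse pass would.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Forward stick-breaking transform: returns only the simplex x and does not
// keep the intermediate break proportions z_k.
std::vector<double> stick_breaking_constrain(const std::vector<double>& y) {
  const std::size_t N = y.size();
  std::vector<double> x(N + 1);
  double stick_len = 1.0;
  for (std::size_t k = 0; k < N; ++k) {
    const double adj_y = y[k] - std::log(static_cast<double>(N - k));
    const double z_k = 1.0 / (1.0 + std::exp(-adj_y));  // inverse logit
    x[k] = stick_len * z_k;
    stick_len -= x[k];
  }
  x[N] = stick_len;
  return x;
}

// Walk backwards over x and rebuild z_k = x_k / (stick length before piece k).
std::vector<double> recover_break_proportions(const std::vector<double>& x) {
  assert(!x.empty());
  const std::size_t N = x.size() - 1;
  std::vector<double> z(N);
  double stick_len = x[N];
  for (std::size_t k = N; k-- > 0;) {
    stick_len += x[k];  // stick length before piece k was broken off
    z[k] = x[k] / stick_len;
  }
  return z;
}
```

The trade-off is a little extra arithmetic in the reverse pass in exchange for not holding one more vector per transform in memory.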
Tests
Existing tests pass
Side Effects
None
Release notes
Reduce the memory overhead of the simplex constraints in reverse mode.
Checklist
Copyright holder: Simons Foundation
The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
- the basic tests are passing
    - `./runTests.py test/unit`
    - `make test-headers`
    - `make test-math-dependencies`
    - `make doxygen`
    - `make cpplint`
- the code is written in idiomatic C++ and changes are documented in the doxygen
- the new changes are tested