pairplot refactoring #1529
Hi @danielmk, great effort and really nice additions! I think this PR would already make life much easier for plotting SBI results. I have left quite a few comments, some more minor than others, so I'll provide a summary here:
- I completely agree that it would be nice to expose the new dataclasses to the user so that they can use autocomplete, but I'm happy to do this in a future PR. For now, it is already great that the user can import and use these, and that `pairplot` and `marginal_plot` warn if any user-provided `kwargs` are ignored. Maybe a small addition in this PR would be to add all the new dataclasses to the `__init__` in `sbi/analysis`, so that the user does not have to import them all explicitly to be able to use them. But maybe @janfb has some ideas for how to make the new dataclasses easier to use directly.
- Currently, you've had to add a lot of `#pyright: ignore`. Sometimes there is no workaround for these, but when we have to add a lot of these ignore statements, it means that we are probably doing something wrong. In the case of this PR, I think a lot of these can be avoided by updating the type hints (e.g., `prepare_for_plot` now explicitly returns an `np.ndarray`, and the code for other functions that call `prepare_for_plot` assumes that they get an `np.ndarray`, but the type hint for what `prepare_for_plot` returns is still `List[np.ndarray]`). I expect that correcting this will allow us to remove a lot of the pyright errors. I have commented individually on some of these pyright errors, but not all. It would be good to double check which `ignore` statements are strictly necessary and which we can remove with updated type hints.
- Regarding the tests currently failing: as discussed separately, locally running
ruff format sbi
ruff format tests
ruff check sbi --fix
ruff check tests --fix
pyright sbi
should fix the linting/pyright errors. I also see that the test suite cancels after a lot of tests fail. Can you run `plot_test.py` locally to check whether this is related to your changes? I think it's likely something about the new types that can be quickly fixed if we can see what the error message is 😄
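To illustrate the type-hint point in the summary, here is a minimal sketch, not the PR's actual code. `to_numpy_samples` and `first_dim` are hypothetical names standing in for `prepare_for_plot` and one of its callers, and the conversion via `np.asarray` is an assumption about how the samples are normalised. Once the return annotation says `np.ndarray`, the caller can read `.shape[1]` without an ignore comment.

```python
from typing import List, Union

import numpy as np

# Hypothetical input type for this sketch; torch.Tensor is left out
# to keep the example free of extra dependencies.
SamplesLike = Union[np.ndarray, List[List[float]]]


def to_numpy_samples(samples: SamplesLike) -> np.ndarray:
    """Return samples as a 2D array of shape (num_samples, dim)."""
    # np.asarray copies only if needed; a float64 array is returned as-is.
    arr = np.asarray(samples, dtype=float)
    if arr.ndim != 2:
        raise ValueError(f"expected 2D samples, got {arr.ndim}D")
    return arr


def first_dim(samples: SamplesLike) -> int:
    """Hypothetical caller that reads the parameter dimension."""
    arr = to_numpy_samples(samples)
    # No `# type: ignore` needed: the checker knows `arr` is an ndarray.
    return arr.shape[1]
```

The design choice is that the annotation and the runtime value agree, so the type checker can verify downstream attribute access instead of being silenced.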
@@ -81,8 +270,10 @@ def plt_hist_1d(
limits: torch.Tensor,
diag_kwargs: Dict,
) -> None:
# ax.hist(samples, **diag_kwargs)
is this a leftover from debugging?
samples[i], copy=False, nan=np.nan, posinf=np.nan, neginf=np.nan
)
samples[i] = samples[i][~np.isnan(samples[i]).any(axis=1)]
# for i in range(len(samples)):
remove these entirely instead of commenting them out
@dataclass
class GenericMplKwargs(GenericKwargs):
"""MplKwargs is used to generate kwargs that are passed to matplotlib in pairplot.
The docstring is written with `MplKwargs`, but the class is called `GenericMplKwargs`.
epsilon_range = eps * max_min_range
limits.append([min_val - epsilon_range, max_val + epsilon_range])
return limits
# limits = []
remove instead of commenting out
@@ -612,21 +806,23 @@ def prepare_for_plot(
of the samples.
"""
samples = convert_to_list_of_numpy(samples)
# samples = convert_to_list_of_numpy(samples)
remove instead of commenting out
@@ -883,9 +1171,10 @@ def marginal_plot(
labels: List of strings specifying the names of the parameters.
ticks: Position of the ticks.
diag_kwargs: Additional arguments to adjust the diagonal plot,
see the source code in `_get_default_diag_kwarg()`
see the source code in `DiagKwargsKDE`, `DiagKwargsHist` and
For example, `TNum` is not used for this function!
fig_kwargs_filled = _get_default_fig_kwargs()
fig_kwargs_filled = _update(fig_kwargs_filled, fig_kwargs)
fig_kwargs_default = FigKwargs() # Get defaults
#if type(fig_kwargs) == FigKwargs:
leftover from debugging?
Returns:
Fig: matplotlib figure
Axes: matplotlib axes
"""
dim = samples[0].shape[1]
dim = samples.shape[1] # type: ignore
The type hint for `samples` is still a `List`, which is why this breaks. For functions that are not exposed to the user, such as `arrange_grid`, it's fine to change the type hint without adding a warning, as the user is not meant to use this function directly.
ax, sample[:, row], limits[row], diag_kwargs[sample_ind]
)
# for sample_ind, sample in enumerate(samples.T):
remove instead of commenting out
@@ -1447,16 +1654,16 @@ def _arrange_grid(
if excl_lower:
ax.axis("off") # pyright: ignore reportOptionalMemberAccess
else:
for sample_ind, sample in enumerate(samples):
lower_f = lower_funcs[sample_ind]
for _, _ in enumerate(samples.T): # type: ignore
`samples` is still a list here
This PR is my (delayed) contribution to the 2025 hackathon, where I tried to resolve issues with the user interface of the widely used `pairplot` function, as in #1425.

The PR addresses the following main issues:

- Default kwargs previously came from functions such as `_get_default_diag_kwargs`, which returned dictionaries. I replaced those functions with dataclasses that contain the default values. These dataclasses can easily be converted to dictionaries by calling `dict(FigKwargs)`, so few changes are required internally and users can still use the standard way of passing kwargs as dictionaries. The dataclasses are considered more Pythonic, they expose their internals more clearly to the user, and in the future they could be passed to `pairplot` instead of dictionaries to specify keyword arguments. I currently don't know how best to make them available to the user, since they need to be explicitly imported right now, but that could be part of a future PR.
- `samples` passed by the user were converted internally to a list of numpy arrays. Instead I now call `np.ndarray(samples)`, which creates a copy if necessary but changes nothing if samples are already a numpy array. IMO passing samples as an ndarray should be strongly encouraged, and it should be either an `np.ndarray` or a `torch.Tensor`. But for now lists are also supported for user flexibility.
- `diag_kwargs`, `upper_kwargs`, `lower_kwargs`, `diag`, `upper` and `lower` all accept lists, most likely with the intention that the user could choose a different plot type and different parameters for each plot. However, this was actually not working in the main branch: only the first entry was used. I added user warnings that warn the user about this when they pass a list for any of these arguments. We should consider whether this feature is actually desired. If not, the code could be massively simplified.
- The kwargs dictionaries contain a nested entry, as in `{'mpl_kwargs': {}}`, where only the entries in `mpl_kwargs` are actually passed to matplotlib. So with `{'bins':10, 'mpl_kwargs': {}}`, the `'bins'` entry was silently ignored. Instead, `{'mpl_kwargs': {'bins':10}}` would be required. If any entries in any `kwargs` are known to be ignored downstream, the user receives a warning about this issue. This is achieved by comparing the user-provided dict with the parameters defined in the default dataclasses.

There are still many issues with pairplot.py IMO, and I am open to describing them in separate issues and continuing work on those.