Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pairplot refactoring #1529

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

pairplot refactoring #1529

wants to merge 1 commit into from

Conversation

danielmk
Copy link
Collaborator

This PR is my (delayed) contribution to the 2025 hackathon, where I tried to resolve issues with the user interface of the widely used pairplot function, as in #1425

The PR addresses the following main issue:

  • The default parameters for kwargs were hardcoded in functions such as _get_default_diag_kwargs, which returned dictionaries. I replaced those functions with dataclasses that contain the default values. These dataclasses can easily be converted to dictionaries by calling dict(FigKwargs) so few changes are required internally and users can still use the standard way of passing kwargs as dictionaries. But the dataclasses are considered more Pythonic, they expose their internals more clearly to the user and in the future they could be passed to pairplot instead of dictionaries to specify keyword arguments. Although I currently don't know how to best make them available to the user, since they need to be explicitly imported right now. But that could be part of a future PR.
  • The samples passed by the user were converted internally to a list of numpy arrays. Instead I now call np.ndarray(samples), which creates a copy if necessary but changes nothing if samples are already a numpy array. IMO passing samples as ndarray should be strongly encouraged and it should either be an np.ndarray or a torch.Tensor. But for now lists are also supported for user flexibility.
  • diag_kwarg upper_kwargs lower_kwargs diag, upper and lower all accept lists most likely with the intention that the user could chose a different plot type and different parameters for each plot. However, this was actually not working in the main branch. Instead only the first entry was used. I added user warnings that warn the user about this when they pass a list for any of these arguments. We should consider if this feature is actually desired. If not, the code could be massively simplified.
  • The way kwargs are passed has caused confusion, because they are passed as a nested dictionary {'mpl_kwargs': {}}, where only the entries in mpl_kwargs are actually passed to matplotlib. So {'bins':10, 'mpl_kwargs': {}}, the 'bins' entry was siltently ignored. Instead, {'mpl_kwargs': {'bins':10}} would be required. If any entries in any kwargs is known to be ignored downstream, the user receives a warning about his issue. This is achieved by comparing the user provided dict with the parameter defined in the default dataclasses.

There are still many issues with pairplot.py IMO and I am open to describing them in separate issues and continue work on those.

@danielmk danielmk requested a review from gmoss13 March 24, 2025 08:53
Copy link
Contributor

@gmoss13 gmoss13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @danielmk, great effort and really nice additions! I think this PR would already make life much easier for plotting SBI results. I have left some quite some comments, some more minor than others, so I'll provide a summary here:

  • I completely agree that it would be nice to expose the new dataclasses to the user so that they can use autocomplete, but happy to do this in a future PR. For now, already the fact that the user can import and use these, as well as the fact that we have a warning in pairplot and marginal_plot if any user-provided kwargs are ignored is already great. Maybe a small addition in this PR would to add all the new dataclasses to the __init__ in sbi/analysis so that the user does not have to import them all explicitly to be able to use them. But maybe @janfb has some ideas for how to make the new dataclasses more easy to use directly by the user.

  • Currently, you've had to add a lot of #pyright: ignore. Sometimes, there is no workaround for these, but when we have to add a lot of these ignore statements it means that we are probably doing something wrong. In the case of this PR, I think a lot of these can be avoided by updating the type hints (e.g. , prepare_for_plot now explicitly returns an np.ndarray, and the code for other functions that call prepare_for_plot assume that they get an np.ndarray, but the type hint for what prepare_for_plot returns is still List[np.ndarray]. I expect that correcting this will allow us to remove a lot of the pyright errors. I have commented individually on these pyright errors, but not all. Would be good to double check which ignore statements are strictly necessary and which we can remove with updated type hints.

  • regarding the tests currently failing. As discussed separately, locally running:

ruff format sbi
ruff format tests
ruff check sbi --fix
ruff check tests --fix
pyright sbi

should fix the linting/pyright errors. I also see that the test suite cancels after a lot of tests fail. Can you check by running plot_test.py locally to see if this is related to your changes or not? I think it's likely something about the new types that can be quickly fixed if we can see what the error message is 😄

@@ -81,8 +270,10 @@ def plt_hist_1d(
limits: torch.Tensor,
diag_kwargs: Dict,
) -> None:
# ax.hist(samples, **diag_kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a leftover from debugging?

samples[i], copy=False, nan=np.nan, posinf=np.nan, neginf=np.nan
)
samples[i] = samples[i][~np.isnan(samples[i]).any(axis=1)]
# for i in range(len(samples)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove these entirely instead of commenting them out


@dataclass
class GenericMplKwargs(GenericKwargs):
"""MplKwargs is used to generate kwargs that are passed to matplotlib in pairplot.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring is written with MplKwargs, but the class is called GenericMplKwargs

epsilon_range = eps * max_min_range
limits.append([min_val - epsilon_range, max_val + epsilon_range])
return limits
# limits = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove instead of commenting out

@@ -612,21 +806,23 @@ def prepare_for_plot(
of the samples.
"""

samples = convert_to_list_of_numpy(samples)
# samples = convert_to_list_of_numpy(samples)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove instead of commenting out

@@ -883,9 +1171,10 @@ def marginal_plot(
labels: List of strings specifying the names of the parameters.
ticks: Position of the ticks.
diag_kwargs: Additional arguments to adjust the diagonal plot,
see the source code in `_get_default_diag_kwarg()`
see the source code in `DiagKwargsKDE`, `DiagKwargsHist` and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, TNum is not used for this function!

fig_kwargs_filled = _get_default_fig_kwargs()
fig_kwargs_filled = _update(fig_kwargs_filled, fig_kwargs)
fig_kwargs_default = FigKwargs() # Get defaults
#if type(fig_kwargs) == FigKwargs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover from debugging?


Returns:
Fig: matplotlib figure
Axes: matplotlib axes
"""
dim = samples[0].shape[1]
dim = samples.shape[1] # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint for samples is still a List, which is why this breaks. For functions that are not exposed to the user, such as arrange_grid, it's fine to change the type hint without adding a warning, as the user is not meant to use this function directly.

ax, sample[:, row], limits[row], diag_kwargs[sample_ind]
)

# for sample_ind, sample in enumerate(samples.T):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove instead of commenting out

@@ -1447,16 +1654,16 @@ def _arrange_grid(
if excl_lower:
ax.axis("off") # pyright: ignore reportOptionalMemberAccess
else:
for sample_ind, sample in enumerate(samples):
lower_f = lower_funcs[sample_ind]
for _, _ in enumerate(samples.T): # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

samples is still a list here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants