Add Permutation Importance #202

Open
wants to merge 34 commits into
base: main

mayer79
Collaborator

@mayer79 mayer79 commented May 10, 2025

Implements #201

@mayer79 mayer79 self-assigned this May 10, 2025
@mayer79 mayer79 added the enhancement New feature or request label May 10, 2025
@mayer79 mayer79 marked this pull request as draft May 10, 2025 10:04
@mayer79 mayer79 changed the title Add compute_permutation_importance() Add Permutation Importance May 10, 2025
@mayer79
Collaborator Author

mayer79 commented May 16, 2025

This is the current basic call:

import numpy as np
import polars as pl
from sklearn.linear_model import LinearRegression

from model_diagnostics.xai import plot_permutation_importance

rng = np.random.default_rng(1)
n = 1000

X = pl.DataFrame(
    {
        "area": rng.uniform(30, 120, n),
        "rooms": rng.choice([2.5, 3.5, 4.5], n),
        "age": rng.uniform(0, 100, n),
    }
)

y = X["area"] + 20 * X["rooms"] + rng.normal(0, 1, n)

model = LinearRegression()
model.fit(X, y)

_ = plot_permutation_importance(
    predict_function=model.predict,
    X=X,
    y=y,
)

image

The extended feature API allows permuting groups of features jointly, like this:

_ = plot_permutation_importance(
    predict_function=model.predict,
    features={"size": ["area", "rooms"], "age": "age"},
    X=X,
    y=y,
)

image
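For readers unfamiliar with the technique, here is a minimal, dependency-free sketch of what grouped permutation importance computes. It is not the implementation in this PR: the function name `permutation_importance_sketch`, the dict-of-lists stand-in for a dataframe, the choice of mean squared error, and the hand-written `predict` are all assumptions made for illustration. The point it shows is that all columns of a group are shuffled with one shared row permutation, and the importance is the resulting increase in loss.

```python
import random


def mean_squared_error(y_true, y_pred):
    """Average squared difference between observations and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def permutation_importance_sketch(predict, X, y, features, n_repeats=5, seed=0):
    """Return {name: mean increase in MSE} after shuffling each feature group.

    X is a dict mapping column name -> list of values (dataframe stand-in).
    features maps a display name to one column name or a list of column names.
    """
    rng = random.Random(seed)
    n = len(y)
    baseline = mean_squared_error(y, predict(X))
    result = {}
    for name, cols in features.items():
        cols = [cols] if isinstance(cols, str) else list(cols)
        increases = []
        for _ in range(n_repeats):
            # One shared row permutation keeps the columns of a group aligned.
            perm = list(range(n))
            rng.shuffle(perm)
            X_perm = dict(X)  # shallow copy; only permuted columns are replaced
            for col in cols:
                X_perm[col] = [X[col][i] for i in perm]
            increases.append(mean_squared_error(y, predict(X_perm)) - baseline)
        result[name] = sum(increases) / n_repeats
    return result


# Toy data mirroring the example above: "age" has no effect on the response.
data_rng = random.Random(1)
n = 500
X = {
    "area": [data_rng.uniform(30, 120) for _ in range(n)],
    "rooms": [data_rng.choice([2.5, 3.5, 4.5]) for _ in range(n)],
    "age": [data_rng.uniform(0, 100) for _ in range(n)],
}
y = [a + 20 * r + data_rng.gauss(0, 1) for a, r in zip(X["area"], X["rooms"])]


def predict(X):
    """Hypothetical fitted model that recovered the true coefficients."""
    return [a + 20 * r for a, r in zip(X["area"], X["rooms"])]


importance = permutation_importance_sketch(
    predict, X, y, features={"size": ["area", "rooms"], "age": "age"}
)
```

Because the toy `predict` ignores `age`, shuffling that column leaves the loss unchanged, while shuffling the `size` group increases it substantially.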

from model_diagnostics.scoring import SquaredError


def safe_copy(X):
Owner

Might be good to put safe_copy and safe_column_names into _utils.array and add tests for them. I think they cause the current CI failure.

Collaborator Author

My local tests are failing for the Python 3.9 environment only (pandas and pyarrow). I will move the functions to _utils.array, draft some unit tests, and rename safe_column_names() to get_column_names().
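As a rough illustration of what such helpers and their unit tests might look like, here is a duck-typed sketch. The bodies below are assumptions, not the code from this PR: only the names `safe_copy` and `get_column_names` come from the discussion, and `_Frame` is a hypothetical stand-in class used so the example runs without pandas, polars, or pyarrow installed.

```python
import copy


def safe_copy(X):
    """Return a copy of X so that permuting columns later leaves the input intact.

    Assumed dispatch: polars-style objects expose clone(), numpy arrays and
    pandas objects expose copy(); anything else falls back to copy.deepcopy.
    """
    if hasattr(X, "clone"):
        return X.clone()
    if hasattr(X, "copy"):
        return X.copy()
    return copy.deepcopy(X)


def get_column_names(X):
    """Return column names as a list, or positional indices for 2D arrays."""
    if hasattr(X, "column_names"):  # pyarrow-style table
        return list(X.column_names)
    if hasattr(X, "columns"):  # pandas- or polars-style dataframe
        return list(X.columns)
    if hasattr(X, "shape") and len(X.shape) == 2:  # numpy-style 2D array
        return list(range(X.shape[1]))
    raise TypeError(f"Cannot infer column names for type {type(X).__name__}")


class _Frame:
    """Hypothetical minimal dataframe stand-in, used only to exercise the helpers."""

    def __init__(self, data):
        self.data = data
        self.columns = list(data)

    def clone(self):
        return _Frame({key: list(values) for key, values in self.data.items()})


frame = _Frame({"area": [50.0, 80.0], "rooms": [2.5, 3.5]})
frame_copy = safe_copy(frame)
frame_copy.data["area"][0] = -1.0  # mutate the copy only
names = get_column_names(frame)
```

A unit test along these lines would check that mutating the copy leaves the original untouched and that column names come back in order.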

@lorentzenchr
Owner

This will be a great addition! Thanks @mayer79

@lorentzenchr
Owner

The failing test is in the python 3.9 env with
numpy 1.22.0
polars 1.0.0
scipy 1.10.0
pandas 1.5.3
pyarrow 11.0.0

Could you check if increasing one of the versions fixes the problem, e.g. polars version?
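When bisecting which pin matters, it can help to confirm what the environment actually resolved. The snippet below is a small sketch using only the standard library's `importlib.metadata`; the helper name `installed_versions` is made up for this example and is not part of the project.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


versions = installed_versions(["numpy", "polars", "scipy", "pandas", "pyarrow"])
for name, ver in versions.items():
    print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Running this once per candidate environment makes it easy to compare the failing and passing setups side by side.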

@mayer79
Collaborator Author

mayer79 commented May 24, 2025

> The failing test is in the python 3.9 env with numpy 1.22.0, polars 1.0.0, scipy 1.10.0, pandas 1.5.3, pyarrow 11.0.0
>
> Could you check if increasing one of the versions fixes the problem, e.g. polars version?

The following changes in the Python 3.9 env would be necessary. I don't know how much it would hurt to drop support for pandas 1.

  • pyarrow 11 -> 13
  • pandas 1.5 -> 2.0

I have added some additional unit tests and moved safe_copy() and get_column_names() to array.py.

@mayer79 mayer79 marked this pull request as ready for review May 29, 2025 08:37