The Interpretability Extension adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. SHAP values explain a single prediction by attributing the prediction's deviation from the baseline (the mean prediction) to individual features, providing a consistent, game-theoretic measure of feature influence. Mathematically, each SHAP value is a weighted average of a feature's marginal contributions across all possible feature combinations. SHAP values can be used to:
  • See which features drive model predictions.
  • Compare feature importance across samples.
  • Debug unexpected model behavior.
The extension also provides an easy interface for TabPFN Partial Dependence Plots and feature selection.
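To make "marginal contribution across all possible feature combinations" concrete, here is a small plain-Python sketch (no TabPFN or shapiq involved) that computes exact Shapley values for a toy 3-feature value function by brute-force enumeration. The value function and contributions are invented for illustration; for an additive model, each Shapley value recovers that feature's contribution, and the values plus the baseline sum to the full prediction:

```python
from itertools import combinations
from math import factorial

def shapley_values(value, n):
    """Exact Shapley values by enumerating all 2^n coalitions.
    `value(S)` maps a frozenset of feature indices to the model output
    when only those features are present."""
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for S in combinations(others, size):
                S = frozenset(S)
                # Shapley weight: |S|! * (n - |S| - 1)! / n!
                w = factorial(len(S)) * factorial(n - len(S) - 1) / factorial(n)
                phi[i] += w * (value(S | {i}) - value(S))
    return phi

# Toy additive model: baseline 0.5 plus a fixed contribution per feature
contrib = {0: 0.2, 1: -0.1, 2: 0.3}

def value(S):
    return 0.5 + sum(contrib[i] for i in S)

phi = shapley_values(value, 3)
print(phi)  # each value recovers its feature's contribution (up to float error)

# Efficiency property: baseline + sum of SHAP values == full prediction
assert abs(0.5 + sum(phi) - value(frozenset({0, 1, 2}))) < 1e-9
```

Real models are not additive, which is why the weighted average over all coalitions (or an approximation of it, as shapiq computes) is needed.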
[Images: SHAP waterfall plot; SHAP feature importance]

Installation

pip install tabpfn-client "tabpfn-extensions[interpretability]"
This installs shapiq, shap, and the other dependencies needed for the methods below.

Quickstart

Train a model, explain a single prediction, and plot the result:
Interpretability computations are resource-intensive. This tutorial uses our API client. For fully local execution instead of the cloud API, replace the tabpfn_client import with tabpfn and ensure you have a GPU available. See best practices for GPU setup. All code examples below work identically with either backend.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tabpfn_client import TabPFNClassifier
from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer

X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

explainer = get_tabpfn_explainer(model=clf, data=X_train, labels=y_train)
sv = explainer.explain(X_test.iloc[0:1].values, budget=128)
sv.plot_waterfall()

Choosing a Method

Before diving into each method, here is a summary to help you pick the right tool for the question you are trying to answer.
shapiq vs SHAP: shapiq's TabPFNExplainer removes features and re-contextualizes the model, which matches how TabPFN natively handles missing data. SHAP instead replaces absent features with random background samples. shapiq is faster and produces explanations that are more faithful to TabPFN. We recommend it as the default.
| Method | What it tells you | When to reach for it |
| --- | --- | --- |
| shapiq (recommended) | A modern successor to the classic SHAP library. Tells you which features drove a specific prediction, using a remove-and-recontextualize strategy native to how TabPFN handles missing data. | You want per-sample explanations and care about feature interactions, or you want the fastest Shapley-based method for TabPFN. |
| SHAP | Per-prediction feature attributions via imputation-based permutation. | You need explanations directly comparable to SHAP values from other models (XGBoost, Random Forest, etc.), or you already use the SHAP library in your workflow. |
| Partial Dependence / ICE | The global, marginal effect of one or two features across the entire dataset. | You want to understand how a feature affects the model on average rather than for a single sample, or you want to visually compare TabPFN against another sklearn estimator. |
| Feature Selection | Which minimal subset of features preserves model performance. | You want to simplify your model or identify redundant features before deployment. |
If you are still unsure which method to use, the table below maps the most common questions to the recommended tool.
| Question | Method |
| --- | --- |
| "Why did the model predict this for this sample?" | shapiq: get_tabpfn_explainer |
| "Which feature pairs interact most?" | shapiq: get_tabpfn_explainer with index="k-SII", max_order=2 |
| "How does feature X affect predictions globally?" | Partial Dependence: partial_dependence_plots |
| "I need SHAP values compatible with other models' explanations." | SHAP: get_shap_values, or shapiq: get_tabpfn_imputation_explainer |
| "Which features can I drop without losing accuracy?" | Feature Selection: feature_selection |

Use Cases

Explain a prediction with shapiq

Use Shapley interaction indices to understand not just which features matter, but which feature pairs drive a prediction together.
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tabpfn_client import TabPFNClassifier
from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# k-SII captures pairwise feature interactions
explainer = get_tabpfn_explainer(
    model=clf,
    data=X_train,
    labels=y_train,
    index="k-SII",
    max_order=2,
)

sv = explainer.explain(X_test.iloc[0:1].values, budget=128)
print(sv)              # top interactions ranked by magnitude
sv.plot_waterfall()    # waterfall plot showing additive contributions
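For intuition about what an order-2 interaction measures: in a two-feature game, the pairwise interaction index reduces to a second-order difference of the value function. The toy value functions below are illustrative stand-ins (plain Python, not shapiq itself); an AND-style model, where the output moves only when both features are present, has a nonzero interaction term, while a purely additive model does not:

```python
def pairwise_interaction(value):
    """Second-order difference: the part of value({1, 2}) that neither
    feature explains on its own. Zero for purely additive value functions."""
    return value({1, 2}) - value({1}) - value({2}) + value(set())

def additive(S):            # no interaction: effects just add up
    return 0.3 * (1 in S) + 0.5 * (2 in S)

def both_required(S):       # AND-style: output moves only when both are present
    return 1.0 if {1, 2} <= S else 0.0

print(pairwise_interaction(additive))       # ~0.0
print(pairwise_interaction(both_required))  # 1.0
```

With index="k-SII" and max_order=2, shapiq estimates analogous pairwise terms for every feature pair of the actual model.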

Visualize global feature effects with Partial Dependence Plots

PDP and ICE curves show how a feature affects predictions across the whole dataset, not just one sample.
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tabpfn_client import TabPFNClassifier
from tabpfn_extensions.interpretability.pdp import partial_dependence_plots

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# PDP for two features; set kind="individual" for ICE, or "both" for overlay
partial_dependence_plots(
    clf, X_test.values,
    features=[0, 1],
    kind="average",
    target_class=1,
)
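Conceptually, a partial dependence curve averages the model's predictions over the dataset while sweeping one feature across a grid. A minimal plain-Python sketch with a toy linear model (this illustrates the computation, not the extension's implementation, which delegates to sklearn):

```python
def partial_dependence(predict, X, feature, grid):
    """PD(g) = mean over rows of predict(row with row[feature] replaced by g)."""
    curve = []
    for g in grid:
        preds = [predict([g if j == feature else v for j, v in enumerate(row)])
                 for row in X]
        curve.append(sum(preds) / len(preds))
    return curve

# Toy linear model: prediction = 2 * x0 + x1
def predict(x):
    return 2 * x[0] + x[1]

X = [[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]]
print(partial_dependence(predict, X, feature=0, grid=[0.0, 1.0, 2.0]))
# [3.0, 5.0, 7.0] -- the slope of 2 recovers x0's marginal effect
```

ICE curves (kind="individual") skip the final averaging step and plot one curve per row instead.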

Compare TabPFN vs other model explanations side-by-side

If you need to compare TabPFN explanations against SHAP explanations of another model using the exact same imputation strategy, use get_tabpfn_imputation_explainer. This wraps shapiq’s generic TabularExplainer with marginal imputation — the same approach the SHAP library uses.
from tabpfn_extensions.interpretability.shapiq import get_tabpfn_imputation_explainer

# Imputation-based explanation (same strategy as SHAP)
impute_explainer = get_tabpfn_imputation_explainer(
    model=clf,
    data=X_train,
    index="SV",
    max_order=1,
    imputer="marginal",
)
sv_impute = impute_explainer.explain(X_test.iloc[0:1].values, budget=128)

Feature selection

Sequential feature selection identifies a small subset of features that preserves most of the model's performance:
from tabpfn_extensions.interpretability.feature_selection import feature_selection

selector = feature_selection(clf, X_train.values, y_train.values, n_features_to_select=5)
X_selected = selector.transform(X_test.values)
print("Selected feature indices:", selector.get_support(indices=True))
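Under the hood, the extension wraps sklearn's SequentialFeatureSelector. The greedy forward-selection idea it implements can be sketched in a few lines of plain Python with a toy scoring function (a conceptual illustration only; the real selector scores subsets with cross-validated model performance):

```python
def forward_select(score, n_features, k):
    """Greedy forward selection: repeatedly add the feature whose
    inclusion improves score(subset) the most."""
    selected = []
    while len(selected) < k:
        remaining = [f for f in range(n_features) if f not in selected]
        best = max(remaining, key=lambda f: score(selected + [f]))
        selected.append(best)
    return sorted(selected)

# Toy score: features 1 and 4 carry no signal
useful = {0: 0.4, 2: 0.3, 3: 0.3}

def score(subset):
    return sum(useful.get(f, 0.0) for f in subset)

print(forward_select(score, n_features=5, k=3))  # [0, 2, 3]
```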

Controlling the budget parameter

The budget parameter in explainer.explain() sets how many coalition samples shapiq evaluates to approximate Shapley values. Each coalition is a subset of features — evaluating more of them produces more accurate estimates but costs more model calls. In theory, exact Shapley values require evaluating all 2^n feature subsets (e.g. 1024 for 10 features, ~1 billion for 30). In practice, shapiq’s approximation algorithms converge well before that:
| Dataset size | Suggested budget | Notes |
| --- | --- | --- |
| Few features (< 10) | 64–128 | Converges quickly; low budgets are fine |
| Medium (10–20 features) | 128–512 | Good accuracy/speed tradeoff |
| Many features (20+) | 512–2048 | Higher budgets help, but returns diminish |
Start low (e.g. budget=128) and increase only if the resulting explanations look noisy or unstable across repeated runs.
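As a back-of-the-envelope illustration of why approximation is necessary (plain Python, independent of shapiq), compare the exact coalition count against the fraction a typical budget covers:

```python
# Exact Shapley values require all 2^n coalitions; the budget caps how many
# coalitions shapiq actually evaluates.
for n_features, budget in [(8, 128), (15, 256), (30, 1024)]:
    exact = 2 ** n_features          # 256, 32768, 1073741824
    coverage = min(budget / exact, 1.0)
    print(f"{n_features:>2} features: {exact:>13,} coalitions; "
          f"budget {budget} evaluates {coverage:.4%} of them")
```

Despite the tiny coverage at 30 features, shapiq's estimators weight the sampled coalitions so that the approximation remains useful at these budgets.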

Compute SHAP values using the shap library

The classic SHAP library uses permutation-based imputation. It is less computationally efficient than shapiq.
SHAP’s permutation explainer scales with the number of features. On datasets with many features, expect longer runtimes. For faster results, consider using shapiq or passing a smaller subset to get_shap_values.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tabpfn_client import TabPFNClassifier
from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap

X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

shap_values = get_shap_values(clf, X_test)

# Aggregate bar chart + per-sample beeswarm
plot_shap(shap_values)

Library Reference

interpretability.shapiq.get_tabpfn_explainer

Creates a shapiq TabPFNExplainer that uses the remove-and-recontextualize paradigm for TabPFN models.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| model | TabPFNClassifier \| TabPFNRegressor | required | Fitted TabPFN model |
| data | DataFrame \| ndarray | required | Background / training data |
| labels | DataFrame \| ndarray | required | Labels for the background data |
| index | str | "k-SII" | Shapley index type. Options: "SV" (Shapley values), "k-SII" (k-Shapley interaction index), "SII", "FSII", "FBII", "STII". With max_order=1, "k-SII" reduces to standard Shapley values. |
| max_order | int | 2 | Maximum interaction order. Set to 1 for single-feature attributions only (no interactions). |
| class_index | int \| None | None | Class to explain for classification models. Defaults to class 1 when None. Ignored for regression. |
| **kwargs | | | Additional keyword arguments forwarded to shapiq.TabPFNExplainer |

Returns: shapiq.TabPFNExplainer. Call .explain(x, budget=N), where x is a 2D numpy array of shape (1, n_features) and budget is the number of coalition samples to evaluate (see Controlling the budget parameter). Returns a shapiq.InteractionValues object with .plot_waterfall(), .plot_force(), and other visualization methods.

interpretability.shapiq.get_tabpfn_imputation_explainer

Creates a shapiq TabularExplainer that uses imputation-based feature removal (same strategy as SHAP).
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| model | TabPFNClassifier \| TabPFNRegressor | required | Fitted TabPFN model |
| data | DataFrame \| ndarray | required | Background data for imputation sampling |
| index | str | "k-SII" | Shapley index type (same options as above) |
| max_order | int | 2 | Maximum interaction order |
| imputer | str | "marginal" | Imputation method. See shapiq docs for available options. |
| class_index | int \| None | None | Class to explain (classification only) |
| **kwargs | | | Additional keyword arguments forwarded to shapiq.TabularExplainer |

Returns: shapiq.TabularExplainer. Same .explain(x, budget=N) interface as above.

interpretability.shap.get_shap_values

Computes SHAP values using a permutation-based explainer with automatic backend selection for TabPFN models.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| estimator | sklearn-compatible model | required | Fitted model (TabPFN or any sklearn estimator) |
| test_x | DataFrame \| ndarray \| Tensor | required | Samples to explain |
| attribute_names | list[str] \| None | None | Feature names when test_x is a numpy array |
| **kwargs | | | Forwarded to the underlying shap.Explainer |
Returns: shap.Explanation — access .values for a numpy array of shape (n_samples, n_features) (regression) or (n_samples, n_features, n_classes) (classification).
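A minimal numpy sketch of working with that array in the classification case; the values below are random stand-ins rather than real SHAP output:

```python
import numpy as np

# get_shap_values(...).values for a 3-class classifier has shape
# (n_samples, n_features, n_classes); random stand-in values here
values = np.random.default_rng(0).normal(size=(5, 4, 3))

# Per-sample attributions for one class, e.g. class 1: (n_samples, n_features)
class_1 = values[:, :, 1]

# Global importance for that class: mean |SHAP| per feature -> (n_features,)
importance = np.abs(class_1).mean(axis=0)
print(class_1.shape, importance.shape)  # (5, 4) (4,)
```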

interpretability.shap.plot_shap

Visualizes SHAP values as an aggregate bar chart, a per-sample beeswarm plot, and (if more than one sample) an interaction scatter for the most important feature.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| shap_values | shap.Explanation | required | Output from get_shap_values |
Returns: None (displays matplotlib figures).

interpretability.pdp.partial_dependence_plots

Convenience wrapper around sklearn’s PartialDependenceDisplay.from_estimator.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| estimator | sklearn-compatible model | required | Fitted estimator |
| X | ndarray | required | Input features |
| features | list[int \| tuple[int, int]] | required | Feature indices for 1D plots, or (i, j) tuples for 2D interaction plots |
| grid_resolution | int | 20 | Number of grid points per feature axis |
| kind | str | "average" | "average" for PDP, "individual" for ICE curves, "both" for overlay |
| target_class | int \| None | None | For classifiers: which class probability to plot |
| ax | matplotlib.axes.Axes \| None | None | Optional axes to plot into |
| **kwargs | | | Forwarded to PartialDependenceDisplay.from_estimator |
Returns: sklearn.inspection.PartialDependenceDisplay

interpretability.feature_selection.feature_selection

Forward sequential feature selection using cross-validation.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| estimator | sklearn-compatible model | required | Fitted estimator |
| X | ndarray | required | Input features |
| y | ndarray | required | Target values |
| n_features_to_select | int | 3 | Number of features to select |
| feature_names | list[str] \| None | None | Feature names (optional) |
| **kwargs | | | Forwarded to sklearn.feature_selection.SequentialFeatureSelector |
Returns: sklearn.feature_selection.SequentialFeatureSelector — call .transform(X) to reduce features, or .get_support(indices=True) to get selected indices.

Related guides

  • Best Practices: GPU setup, batch inference, and performance tuning.
  • Classification: binary and multi-class classification guide.
  • Regression: point estimates, quantiles, and full distributions.
  • Fine-Tuning: adapt TabPFN to your domain-specific data.