This repository provides the official implementation accompanying the paper:
QuadraSHAP: Stable and Scalable Shapley Values for Product Games via Gauss-Legendre Quadrature
QuadraSHAP reformulates Shapley-value computation for product games as a Gauss-Legendre quadrature problem, yielding estimates that are both numerically stable and scalable to high-dimensional settings. The library covers two concrete application domains:
- `TreeExplainer`: TreeSHAP-style explanations for scikit-learn tree models, with interchangeable numerical backends.
- Product-kernel explainers: local Shapley values for models whose prediction function factorizes across features, such as RBF kernel methods.
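To make the reformulation concrete, here is a minimal sketch of the standard identity behind it. It assumes a product game with value $v(S)=\prod_{j\in S} a_j$ and $v(\emptyset)=1$. Because the Shapley weight $\frac{s!\,(n-s-1)!}{n!}$ equals $\int_0^1 t^s(1-t)^{n-1-s}\,dt$, the Shapley value collapses to $\phi_i=(a_i-1)\int_0^1\prod_{j\neq i}(1-t+t\,a_j)\,dt$, which Gauss-Legendre quadrature can evaluate. The function names and the array `a` are illustrative only and are not part of the quadrashap API; the library's actual implementation may differ.

```python
import itertools
import math

import numpy as np


def shapley_bruteforce(a):
    """Shapley values of v(S) = prod_{j in S} a_j by explicit subset enumeration."""
    n = len(a)
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = (math.factorial(size) * math.factorial(n - size - 1)
                      / math.factorial(n))
            for subset in itertools.combinations(others, size):
                v_subset = np.prod([a[j] for j in subset])  # empty product is 1.0
                # Marginal contribution of i: v(S + i) - v(S) = (a_i - 1) * v(S).
                phi[i] += weight * (a[i] - 1.0) * v_subset
    return phi


def shapley_quadrature(a, num_nodes):
    """Same values from the integral form, using Gauss-Legendre nodes on [0, 1]."""
    a = np.asarray(a, dtype=float)
    nodes, weights = np.polynomial.legendre.leggauss(num_nodes)
    t = 0.5 * (nodes + 1.0)   # map nodes from [-1, 1] to [0, 1]
    w = 0.5 * weights         # rescale weights for the shorter interval
    factors = 1.0 - t[:, None] + t[:, None] * a[None, :]  # shape (num_nodes, n)
    phi = np.empty_like(a)
    for i in range(len(a)):
        leave_one_out = np.prod(np.delete(factors, i, axis=1), axis=1)
        phi[i] = (a[i] - 1.0) * np.dot(w, leave_one_out)
    return phi


a = np.array([0.5, 1.2, 2.0, 0.8, 1.5])
phi_exact = shapley_bruteforce(a)
phi_quad = shapley_quadrature(a, num_nodes=3)
print(np.max(np.abs(phi_exact - phi_quad)))
```

With five players the integrand is a polynomial of degree 4, and a 3-node Gauss-Legendre rule is exact up to degree 5, so the two results should agree to floating-point precision while the quadrature cost grows polynomially rather than exponentially in the number of players.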
The repository is organized as a research artifact: library code lives under src/, correctness tests under tests/, and benchmark scripts with precomputed outputs under benchmarks/.
| Path | Description |
|---|---|
| `src/quadrashap/` | Package source code |
| `src/quadrashap/treeshap/` | Tree-model explainers and numerical backends |
| `src/quadrashap/kernels/` | Explainers for product-form kernel models |
| `csrc/` | Optional C++ extension for the quadrature-tree backend |
| `tests/` | Correctness and regression tests |
| `benchmarks/` | Scripts for runtime and approximation experiments |
| `benchmarks/results/` | Saved benchmark outputs and figures |
| `model/` | Cached models used by the text-classification benchmarks |
The package requires Python >=3.11.
Using uv (recommended):

    uv sync --extra jax --group testing

Using pip:

    python -m venv .venv
    source .venv/bin/activate
    pip install -e .[jax]
    pip install pytest pytest-benchmark scikit-learn shap

Notes
- The build system attempts to compile the optional C++ extension if a compatible compiler is detected. If compilation fails, installation falls back gracefully to a pure-Python build.
- JAX is optional for some backends, but `jax` and `jaxlib` are currently declared as core dependencies in `pyproject.toml`.
quadrashap.TreeExplainer follows the familiar SHAP-style interface for supported scikit-learn tree models.

    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.ensemble import RandomForestRegressor

    from quadrashap import TreeExplainer

    X, y = make_regression(n_samples=300, n_features=6, random_state=0)
    model = RandomForestRegressor(n_estimators=8, max_depth=4, random_state=0).fit(X, y)

    explainer = TreeExplainer(model, tree_solver="product_games")
    phi = explainer.shap_values(X[:10])

    print(phi.shape)  # (10, 6)
    print(explainer.expected_value)

Available tree backends (`tree_solver`):
| Value | Description |
|---|---|
| `"product_games"` | TreeSHAP via product-game factorization |
| `"quadrature_tree"` | Direct quadrature-tree backend |
Useful options:
| Option | Values |
|---|---|
| `backend_method` | `"numpy_prefix_scan"`, `"numpy_logspace"`, `"jax_prefix_scan"`, `"jax_logspace"` |
| `m_q` | Number of quadrature nodes (integer) |
| `use_cpp` | `True` / `False` (quadrature-tree backend only) |
Current limitations:
- Only `model_output="raw"` is supported.
- Only `feature_perturbation="tree_path_dependent"` is implemented.
- The frontend currently targets scikit-learn tree estimators.
For kernel methods with factorized feature kernels, use RBFLocalExplainer or ProductKernelLocalExplainer.

    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.kernel_ridge import KernelRidge

    from quadrashap.kernels.explainer import RBFLocalExplainer

    X, y = make_regression(n_samples=200, n_features=5, random_state=0)
    model = KernelRidge(kernel="rbf", gamma=0.5, alpha=1.0).fit(X, y)

    explainer = RBFLocalExplainer(model)
    phi = explainer.explain(X[0], method="logspace_numpy")

    print(phi.shape)  # (5,)

Supported kernel backends (`method`): `logspace_numpy`, `logspace_jax`, `prefix_scan_numpy`, `prefix_scan_jax`.
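The backend names are not defined in this README. One plausible reading, offered purely as an assumption, is that they refer to two ways of forming the leave-one-out products $\prod_{j\neq i} f_j$ that a product-game quadrature rule needs at every node: a prefix/suffix scan, and a sum of logarithms. The sketch below illustrates both in plain NumPy for strictly positive factors; the function names are hypothetical and nothing here is taken from the quadrashap source.

```python
import numpy as np


def leave_one_out_prefix_scan(factors):
    """Leave-one-out products per row via prefix and suffix cumulative products."""
    m, n = factors.shape
    prefix = np.ones((m, n))
    suffix = np.ones((m, n))
    # prefix[:, k] = product of factors[:, :k]
    prefix[:, 1:] = np.cumprod(factors[:, :-1], axis=1)
    # suffix[:, k] = product of factors[:, k+1:]
    suffix[:, :-1] = np.cumprod(factors[:, :0:-1], axis=1)[:, ::-1]
    return prefix * suffix


def leave_one_out_logspace(factors):
    """Same quantity via logarithms; requires strictly positive factors."""
    log_f = np.log(factors)
    total = log_f.sum(axis=1, keepdims=True)
    return np.exp(total - log_f)


rng = np.random.default_rng(0)
factors = rng.uniform(0.5, 1.5, size=(4, 6))  # rows: quadrature nodes, cols: features

loo_scan = leave_one_out_prefix_scan(factors)
loo_log = leave_one_out_logspace(factors)
print(np.max(np.abs(loo_scan - loo_log)))
```

Both variants avoid dividing the full product by each factor, which would break down when a factor is zero or very small. The log-space form trades a few transcendental calls for better protection against overflow and underflow when many features are multiplied together.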

Run the test suite with:

    pytest tests

The test suite verifies:

- agreement with naive Shapley implementations on small problems;
- frontend conversion from scikit-learn trees to the internal unified format;
- end-to-end agreement with `shap.TreeExplainer` on supported tree models;
- optional C++ extension behavior.
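The frontend-conversion check can be pictured with a small sketch. The internal unified format is not documented here, so the dictionary layout and the helper names `to_unified_tree` and `predict_one` below are hypothetical. The sketch relies only on the public `tree_` attributes of a fitted scikit-learn `DecisionTreeRegressor` and confirms that walking the flattened arrays reproduces the model's own predictions.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor


def to_unified_tree(estimator):
    """Copy a fitted single-output regression tree into flat arrays (hypothetical layout)."""
    t = estimator.tree_
    return {
        "children_left": t.children_left.copy(),    # -1 marks a leaf
        "children_right": t.children_right.copy(),
        "feature": t.feature.copy(),
        "threshold": t.threshold.copy(),
        "value": t.value[:, 0, 0].copy(),           # node mean for a single output
        "cover": t.weighted_n_node_samples.copy(),  # training weight reaching each node
    }


def predict_one(tree, x):
    """Follow split decisions from the root until a leaf is reached."""
    node = 0
    while tree["children_left"][node] != -1:
        if x[tree["feature"][node]] <= tree["threshold"][node]:
            node = tree["children_left"][node]
        else:
            node = tree["children_right"][node]
    return tree["value"][node]


X, y = make_regression(n_samples=200, n_features=5, random_state=0)
model = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y)

tree = to_unified_tree(model)
X32 = X.astype(np.float32)  # scikit-learn compares float32 inputs against thresholds
manual = np.array([predict_one(tree, x) for x in X32])
print(np.max(np.abs(manual - model.predict(X))))
```

The `cover` array is kept because path-dependent TreeSHAP weights each branch by the fraction of training samples that follow it.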
All benchmark scripts are run from the repository root.
Generate raw convergence data for the quadrature-node sweep:

    python benchmarks/bench_mq_sweep.py

Aggregate and plot the results:

    python benchmarks/plot_mq_results.py

Outputs are written to `benchmarks/results/mq/`.

Compare several TreeSHAP implementations across varying tree sizes:

    python benchmarks/treeshap_bench.py

Results are saved to `benchmarks/treeshap_bench_results.json`.

Evaluate tree and kernel explainers on TF-IDF text-classification setups:

    python benchmarks/text_classification_benchmark.py

Outputs are written to `benchmarks/results/text_clf/`.
Additional dependencies may be required: `datasets`, `pandas`, `matplotlib`, `scipy`, `joblib`, and optionally `optuna`.
Saved benchmark artifacts are included for inspection without rerunning experiments:
- `benchmarks/results/mq/`: convergence CSVs and figures for the quadrature-node sweep
- `benchmarks/results/text/`: tables and plots from the text-classification benchmark

Implementation notes:

- The package uses `scikit-build-core` and `pybind11` for the optional C++ extension.
- Tree explanations are computed via an internal unified tree representation converted from scikit-learn models.
- Kernel explainers use Gauss-Legendre quadrature with a configurable number of nodes `m_q`; when unset, a default is chosen based on the feature dimension.
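The link between the number of nodes and the feature dimension can be illustrated with a small sweep. The sketch assumes the plain product-game integrand $\prod_{j\neq i}\bigl(1+t\,(a_j-1)\bigr)$ on $[0,1]$, which is a polynomial of degree $n-1$ for $n$ features, and compares Gauss-Legendre estimates against an exact integral obtained from the polynomial coefficients. The helper names and the vector `a` are illustrative, and the rule the library uses to pick its default `m_q` is not stated in this README.

```python
import numpy as np
from numpy.polynomial import polynomial as P
from numpy.polynomial.legendre import leggauss


def exact_integrals(a):
    """Integrate each leave-one-out polynomial exactly from its coefficients."""
    out = np.empty(len(a))
    for i in range(len(a)):
        coeffs = np.array([1.0])
        for j, a_j in enumerate(a):
            if j != i:
                coeffs = P.polymul(coeffs, [1.0, a_j - 1.0])  # times 1 + (a_j - 1) t
        # The integral of t**k over [0, 1] is 1 / (k + 1).
        out[i] = np.sum(coeffs / np.arange(1, len(coeffs) + 1))
    return out


def quadrature_integrals(a, m_q):
    """Approximate the same integrals with an m_q-node Gauss-Legendre rule."""
    nodes, weights = leggauss(m_q)
    t = 0.5 * (nodes + 1.0)   # map nodes from [-1, 1] to [0, 1]
    w = 0.5 * weights
    factors = 1.0 + t[:, None] * (a[None, :] - 1.0)
    return np.array([
        w @ np.prod(np.delete(factors, i, axis=1), axis=1)
        for i in range(len(a))
    ])


a = np.linspace(0.5, 2.0, 8)  # 8 features, so the integrand has degree 7
reference = exact_integrals(a)
errors = {
    m_q: float(np.max(np.abs(quadrature_integrals(a, m_q) - reference)))
    for m_q in range(1, 7)
}
for m_q, err in errors.items():
    print(m_q, err)
```

An `m_q`-node rule integrates polynomials up to degree `2 * m_q - 1` exactly, so in this idealized setting the error should drop to rounding level once `m_q` reaches about half the number of features. Integrands that are not low-degree polynomials would converge more gradually.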
