Skip to content

Majeed7/QuadraSHAP

Repository files navigation

QuadraSHAP: Exact Shapley Values in Logarithmic Time for Trees and Kernels

alt text

This repository provides the official implementation accompanying the paper:

QuadraSHAP: Stable and Scalable Shapley Values for Product Games via Gauss-Legendre Quadrature

QuadraSHAP reformulates Shapley-value computation for product games as a Gauss-Legendre quadrature problem, yielding estimates that are both numerically stable and scalable to high-dimensional settings. The library covers two concrete application domains:

  • TreeExplainer: TreeSHAP-style explanations for scikit-learn tree models, with interchangeable numerical backends.
  • Product-kernel explainers: local Shapley values for models whose prediction function factorizes across features, such as RBF kernel methods.

The repository is organized as a research artifact: library code lives under src/, correctness tests under tests/, and benchmark scripts with precomputed outputs under benchmarks/.

Repository Structure

Path Description
src/quadrashap/ Package source code
src/quadrashap/treeshap/ Tree-model explainers and numerical backends
src/quadrashap/kernels/ Explainers for product-form kernel models
csrc/ Optional C++ extension for the quadrature-tree backend
tests/ Correctness and regression tests
benchmarks/ Scripts for runtime and approximation experiments
benchmarks/results/ Saved benchmark outputs and figures
model/ Cached models used by the text-classification benchmarks

Installation

The package requires Python >=3.11.

Using uv (recommended):

uv sync --extra jax --group testing

Using pip:

python -m venv .venv
source .venv/bin/activate
pip install -e .[jax]
pip install pytest pytest-benchmark scikit-learn shap

Notes

  • The build system attempts to compile the optional C++ extension if a compatible compiler is detected. If compilation fails, installation falls back gracefully to a pure-Python build.
  • JAX is optional for some backends, but jax and jaxlib are currently declared as core dependencies in pyproject.toml.

Quick Start

Tree models

quadrashap.TreeExplainer follows the familiar SHAP-style interface for supported scikit-learn tree models.

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

from quadrashap import TreeExplainer

X, y = make_regression(n_samples=300, n_features=6, random_state=0)
model = RandomForestRegressor(n_estimators=8, max_depth=4, random_state=0).fit(X, y)

explainer = TreeExplainer(model, tree_solver="product_games")
phi = explainer.shap_values(X[:10])

print(phi.shape)          # (10, 6)
print(explainer.expected_value)

Available tree backends (tree_solver):

Value Description
"product_games" TreeSHAP via product-game factorization
"quadrature_tree" Direct quadrature-tree backend

Useful options:

Option Values
backend_method "numpy_prefix_scan", "numpy_logspace", "jax_prefix_scan", "jax_logspace"
m_q Number of quadrature nodes (integer)
use_cpp True / False (quadrature-tree backend only)

Current limitations:

  • Only model_output="raw" is supported.
  • Only feature_perturbation="tree_path_dependent" is implemented.
  • The frontend currently targets scikit-learn tree estimators.

Product-kernel models

For kernel methods with factorized feature kernels, use RBFLocalExplainer or ProductKernelLocalExplainer.

import numpy as np
from sklearn.datasets import make_regression
from sklearn.kernel_ridge import KernelRidge

from quadrashap.kernels.explainer import RBFLocalExplainer

X, y = make_regression(n_samples=200, n_features=5, random_state=0)
model = KernelRidge(kernel="rbf", gamma=0.5, alpha=1.0).fit(X, y)

explainer = RBFLocalExplainer(model)
phi = explainer.explain(X[0], method="logspace_numpy")

print(phi.shape)          # (5,)

Supported kernel backends (method): logspace_numpy, logspace_jax, prefix_scan_numpy, prefix_scan_jax.

Running Tests

pytest tests

The test suite verifies:

  • agreement with naive Shapley implementations on small problems;
  • frontend conversion from scikit-learn trees to the internal unified format;
  • end-to-end agreement with shap.TreeExplainer on supported tree models;
  • optional C++ extension behavior.

Reproducing Experiments

All benchmark scripts are run from the repository root.

1. Quadrature-node convergence for kernel explainers

Generate raw convergence data:

python benchmarks/bench_mq_sweep.py

Aggregate and plot results:

python benchmarks/plot_mq_results.py

Outputs are written to benchmarks/results/mq/.

2. TreeSHAP runtime benchmark

python benchmarks/treeshap_bench.py

Compares several TreeSHAP implementations across varying tree sizes. Results are saved to benchmarks/treeshap_bench_results.json.

3. Text-classification benchmark

python benchmarks/text_classification_benchmark.py

Evaluates tree and kernel explainers on TF-IDF text-classification setups. Outputs are written to benchmarks/results/text_clf/.

Additional dependencies may be required: datasets, pandas, matplotlib, scipy, joblib, and optionally optuna.

Precomputed Results

Saved benchmark artifacts are included for inspection without rerunning experiments:

  • benchmarks/results/mq/ — convergence CSVs and figures for the quadrature-node sweep
  • benchmarks/results/text/ — tables and plots from the text-classification benchmark

Implementation Notes

  • The package uses scikit-build-core and pybind11 for the optional C++ extension.
  • Tree explanations are computed via an internal unified tree representation converted from scikit-learn models.
  • Kernel explainers use Gauss-Legendre quadrature with a configurable number of nodes m_q; when unset, a default is chosen based on the feature dimension.

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors