Skip to content

Implement weighted samples and robust TRF (Issue #17)#21

Open
Hugo-W wants to merge 1 commit into
ai-revampfrom
issue-17-weighted-samples
Open

Implement weighted samples and robust TRF (Issue #17)#21
Hugo-W wants to merge 1 commit into
ai-revampfrom
issue-17-weighted-samples

Conversation

@Hugo-W

@Hugo-W Hugo-W commented May 23, 2026

Copy link
Copy Markdown
Owner

Summary

Implements weighted samples and robust TRF estimation as described in Issue #17.

Changes Made

utils.py

  • Added robust_loss() function: d(e) = log(1 + (e/σ)²) for robust error computation
  • Added robust_weights() function: computes weights for IRLS based on residuals
  • Added apply_sample_weights() function: applies W to X and y for weighted least squares (maps X to WX)

models.py (TRFEstimator)

  • Added weights parameter to fit() method for weighted samples
    • When weights are provided, applies them to X and y (WX transformation)
    • Uses same solver as usual after weighting
  • Added robust parameter for robust TRF with IRLS
    • Uses log-based error function for outlier resistance
    • Implements Iteratively Reweighted Least Squares (IRLS)
  • Added solver parameter to __init__(): 'auto', 'svd', or 'cg'
  • Added _solve_trf() helper method to choose between SVD and Conjugate Gradient
  • Updated _fitlists() to support weights and robust mode
  • Conjugate Gradient solver used for robust TRF (iterative reweighting)

Usage Example

# Weighted samples
weights = np.exp(-0.1 * np.arange(n_samples))  # temporal decay
trf = TRFEstimator(tmin=-0.5, tmax=1.2, srate=125)
trf.fit(X, y, weights=weights)

# Robust TRF with IRLS
trf = TRFEstimator(tmin=-0.5, tmax=1.2, srate=125)
trf.fit(X, y, robust=True, max_irls_iter=10, robust_sigma=1.0)

# Access standardized coefficients
print(trf.standardized_coef_)

Mathematical Background

  • Weighted samples: min ||W(y - Xβ)||² → transform to ||W^(1/2)y - W^(1/2)Xβ||²
  • Robust TRF: Uses d(e) = log(1 + (e/σ)²) to reduce outlier influence
  • IRLS: Iteratively reweight based on residuals until convergence

Testing

  • Basic weighted least squares with synthetic data
  • Robust TRF converges with IRLS
  • Conjugate Gradient solver works correctly
  • Backward compatibility (weights=None, robust=False)

Related Issues

Closes #17

- Add robust loss function d(e) = log(1 + (e/σ)²) to utils.py
- Add robust_weights() and apply_sample_weights() helpers
- Update TRFEstimator.fit() to accept weights and robust parameters
- Implement IRLS (Iteratively Reweighted Least Squares) for robust TRF
- Use conjugate gradient solver for iterative reweighted least squares
- Add _solve_trf() helper method to choose between SVD and CG solvers
- Update _fitlists() to support weights and robust mode
- Sample weights are applied as WX transformation (map X to WX)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant