Skip to content

Regression & Multitask#502

Open
kyoon-mit wants to merge 16 commits into
ML4GW:regression-datamodulesfrom
kyoon-mit:feature/regression_models
Open

Regression & Multitask#502
kyoon-mit wants to merge 16 commits into
ML4GW:regression-datamodulesfrom
kyoon-mit:feature/regression_models

Conversation

@kyoon-mit

@kyoon-mit kyoon-mit commented Jun 24, 2026

Copy link
Copy Markdown

Mirroring #452 by pushing into regression-datamodules branch.

kyoon-mit and others added 9 commits June 24, 2026 17:58
…avior for the s4kernel parameters while tracking the modifications in ml4gw.nn.ssm.s4d
…kground. Updated datasets, waveform generators and models to accept/return parameters as part of the batch
…lasses. Updated Autoencoder and Supervised models to inherit from AframeClassification. Added new SupervisedMultiTaskAframe and SupervisedRegressionAframe classes for multi-task and regression tasks.
- Introduced RegressionArchitecture and MultiTaskArchitecture classes in regression.py.
- Added RegressionTimeDomainResNet and MultiTaskTimeDomainResNet classes for regression tasks.
- Created multitask.yaml and regression.yaml for multi-task and regression training configurations.
- Enhanced BaseAframeDataset to support parameter transformations.
- Implemented ChirpMass and MassRatio transforms for parameter calculations.
- Updated AframeWandbLogger for improved logging capabilities.
@kyoon-mit kyoon-mit force-pushed the feature/regression_models branch from 282fa08 to f1ef5ca Compare June 25, 2026 00:41
GaussianNLLRegressionAframe predicts a mean and variance per parameter and
trains with a beta-weighted Gaussian NLL (BetaNLLLoss); detection score is
the negative mean variance. Includes an example S4D config.
Drop the removed length/prenorm/lr arguments and pass d_state.
Read metadata and per-rank slices via h5py/_load_with_idx instead of loading
whole ledgers into memory (WaveformLoader init and WaveformSampler val), add
a num_val_waveforms cap, and update the sampler test mock accordingly.
@kyoon-mit kyoon-mit force-pushed the feature/regression_models branch from 9121d37 to e217dfe Compare June 26, 2026 18:30
@kyoon-mit

Copy link
Copy Markdown
Author

Dependencies

This PR targets regression-datamodules and stacks on #503 (Switch S4D to ml4gw
import). It must be merged after #503, with this branch rebased onto the merged #503.

The base branch regression-datamodules already carries #450, #459, #460, #461.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant