From b7f208a73b51aef1029c7d9d31eae8c4c4cc8e37 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 13:20:37 -0400 Subject: [PATCH 01/16] updating ml4gw version to 0.8.3 --- uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uv.lock b/uv.lock index b3700fae..5e2e711f 100644 --- a/uv.lock +++ b/uv.lock @@ -1911,7 +1911,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1920,9 +1920,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] From 1bf718db17628d8896d56242cd7aed824470467a Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 17:37:09 -0400 Subject: [PATCH 02/16] updating to ml4gw==0.8.3 in all the uv envs --- 
libs/architectures/uv.lock | 45 +++++++++++++++++++++++++++++++++++--- libs/ledger/uv.lock | 8 +++---- libs/p_astro/uv.lock | 8 +++---- libs/priors/uv.lock | 10 +++++---- libs/utils/uv.lock | 6 ++--- projects/data/uv.lock | 6 ++--- projects/export/uv.lock | 6 ++--- projects/infer/uv.lock | 8 +++---- projects/online/uv.lock | 6 ++--- projects/plots/uv.lock | 10 +++++---- projects/train/uv.lock | 6 ++--- 11 files changed, 81 insertions(+), 38 deletions(-) diff --git a/libs/architectures/uv.lock b/libs/architectures/uv.lock index 0d862327..82c5429a 100644 --- a/libs/architectures/uv.lock +++ b/libs/architectures/uv.lock @@ -174,17 +174,18 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.2" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, { name = "numpy" }, + { name = "scipy" }, { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/18/3d/9c58325b0f3d606cebc0e52e8b8600aeeca699096184a5ca0c4a2a2b3024/ml4gw-0.7.2.tar.gz", hash = "sha256:fc9f61fbc6e2fd9ae6b8654d4e0468788b56cf8631f1e4d8b4a55dba80931a90", size = 101166, upload-time = "2025-02-12T22:06:36.278Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/2d/efb9febf79d4b54392f8b45910fed080d78ea2ba377319e34919f9cda1cc/ml4gw-0.7.2-py3-none-any.whl", hash = "sha256:a8f2f508420eba1c6dc045762b1c6452ce51b43ffa3c138e2ca4bd8436540ff5", size = 123216, upload-time = "2025-02-12T22:06:35.171Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = 
"sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -390,6 +391,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083, upload-time = "2024-12-01T12:54:19.735Z" }, ] +[[package]] +name = "scipy" +version = "1.15.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0f/37/6964b830433e654ec7485e45a00fc9a27cf868d622838f6b6d9c5ec0d532/scipy-1.15.3.tar.gz", hash = "sha256:eae3cf522bc7df64b42cad3925c876e1b0b6c35c1337c93e12c0f366f55b0eaf", size = 59419214, upload-time = "2025-05-08T16:13:05.955Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/2f/4966032c5f8cc7e6a60f1b2e0ad686293b9474b65246b0c642e3ef3badd0/scipy-1.15.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:a345928c86d535060c9c2b25e71e87c39ab2f22fc96e9636bd74d1dbf9de448c", size = 38702770, upload-time = "2025-05-08T16:04:20.849Z" }, + { url = "https://files.pythonhosted.org/packages/a0/6e/0c3bf90fae0e910c274db43304ebe25a6b391327f3f10b5dcc638c090795/scipy-1.15.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ad3432cb0f9ed87477a8d97f03b763fd1d57709f1bbde3c9369b1dff5503b253", size = 30094511, upload-time = "2025-05-08T16:04:27.103Z" }, + { url = "https://files.pythonhosted.org/packages/ea/b1/4deb37252311c1acff7f101f6453f0440794f51b6eacb1aad4459a134081/scipy-1.15.3-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:aef683a9ae6eb00728a542b796f52a5477b78252edede72b8327a886ab63293f", size = 22368151, upload-time = "2025-05-08T16:04:31.731Z" }, + { url = 
"https://files.pythonhosted.org/packages/38/7d/f457626e3cd3c29b3a49ca115a304cebb8cc6f31b04678f03b216899d3c6/scipy-1.15.3-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:1c832e1bd78dea67d5c16f786681b28dd695a8cb1fb90af2e27580d3d0967e92", size = 25121732, upload-time = "2025-05-08T16:04:36.596Z" }, + { url = "https://files.pythonhosted.org/packages/db/0a/92b1de4a7adc7a15dcf5bddc6e191f6f29ee663b30511ce20467ef9b82e4/scipy-1.15.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:263961f658ce2165bbd7b99fa5135195c3a12d9bef045345016b8b50c315cb82", size = 35547617, upload-time = "2025-05-08T16:04:43.546Z" }, + { url = "https://files.pythonhosted.org/packages/8e/6d/41991e503e51fc1134502694c5fa7a1671501a17ffa12716a4a9151af3df/scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2abc762b0811e09a0d3258abee2d98e0c703eee49464ce0069590846f31d40", size = 37662964, upload-time = "2025-05-08T16:04:49.431Z" }, + { url = "https://files.pythonhosted.org/packages/25/e1/3df8f83cb15f3500478c889be8fb18700813b95e9e087328230b98d547ff/scipy-1.15.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ed7284b21a7a0c8f1b6e5977ac05396c0d008b89e05498c8b7e8f4a1423bba0e", size = 37238749, upload-time = "2025-05-08T16:04:55.215Z" }, + { url = "https://files.pythonhosted.org/packages/93/3e/b3257cf446f2a3533ed7809757039016b74cd6f38271de91682aa844cfc5/scipy-1.15.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5380741e53df2c566f4d234b100a484b420af85deb39ea35a1cc1be84ff53a5c", size = 40022383, upload-time = "2025-05-08T16:05:01.914Z" }, + { url = "https://files.pythonhosted.org/packages/d1/84/55bc4881973d3f79b479a5a2e2df61c8c9a04fcb986a213ac9c02cfb659b/scipy-1.15.3-cp310-cp310-win_amd64.whl", hash = "sha256:9d61e97b186a57350f6d6fd72640f9e99d5a4a2b8fbf4b9ee9a841eab327dc13", size = 41259201, upload-time = "2025-05-08T16:05:08.166Z" }, + { url = 
"https://files.pythonhosted.org/packages/96/ab/5cc9f80f28f6a7dff646c5756e559823614a42b1939d86dd0ed550470210/scipy-1.15.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:993439ce220d25e3696d1b23b233dd010169b62f6456488567e830654ee37a6b", size = 38714255, upload-time = "2025-05-08T16:05:14.596Z" }, + { url = "https://files.pythonhosted.org/packages/4a/4a/66ba30abe5ad1a3ad15bfb0b59d22174012e8056ff448cb1644deccbfed2/scipy-1.15.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:34716e281f181a02341ddeaad584205bd2fd3c242063bd3423d61ac259ca7eba", size = 30111035, upload-time = "2025-05-08T16:05:20.152Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fa/a7e5b95afd80d24313307f03624acc65801846fa75599034f8ceb9e2cbf6/scipy-1.15.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3b0334816afb8b91dab859281b1b9786934392aa3d527cd847e41bb6f45bee65", size = 22384499, upload-time = "2025-05-08T16:05:24.494Z" }, + { url = "https://files.pythonhosted.org/packages/17/99/f3aaddccf3588bb4aea70ba35328c204cadd89517a1612ecfda5b2dd9d7a/scipy-1.15.3-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6db907c7368e3092e24919b5e31c76998b0ce1684d51a90943cb0ed1b4ffd6c1", size = 25152602, upload-time = "2025-05-08T16:05:29.313Z" }, + { url = "https://files.pythonhosted.org/packages/56/c5/1032cdb565f146109212153339f9cb8b993701e9fe56b1c97699eee12586/scipy-1.15.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:721d6b4ef5dc82ca8968c25b111e307083d7ca9091bc38163fb89243e85e3889", size = 35503415, upload-time = "2025-05-08T16:05:34.699Z" }, + { url = "https://files.pythonhosted.org/packages/bd/37/89f19c8c05505d0601ed5650156e50eb881ae3918786c8fd7262b4ee66d3/scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39cb9c62e471b1bb3750066ecc3a3f3052b37751c7c3dfd0fd7e48900ed52982", size = 37652622, upload-time = "2025-05-08T16:05:40.762Z" }, + { url = 
"https://files.pythonhosted.org/packages/7e/31/be59513aa9695519b18e1851bb9e487de66f2d31f835201f1b42f5d4d475/scipy-1.15.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:795c46999bae845966368a3c013e0e00947932d68e235702b5c3f6ea799aa8c9", size = 37244796, upload-time = "2025-05-08T16:05:48.119Z" }, + { url = "https://files.pythonhosted.org/packages/10/c0/4f5f3eeccc235632aab79b27a74a9130c6c35df358129f7ac8b29f562ac7/scipy-1.15.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:18aaacb735ab38b38db42cb01f6b92a2d0d4b6aabefeb07f02849e47f8fb3594", size = 40047684, upload-time = "2025-05-08T16:05:54.22Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a7/0ddaf514ce8a8714f6ed243a2b391b41dbb65251affe21ee3077ec45ea9a/scipy-1.15.3-cp311-cp311-win_amd64.whl", hash = "sha256:ae48a786a28412d744c62fd7816a4118ef97e5be0bee968ce8f0a2fba7acf3bb", size = 41246504, upload-time = "2025-05-08T16:06:00.437Z" }, + { url = "https://files.pythonhosted.org/packages/37/4b/683aa044c4162e10ed7a7ea30527f2cbd92e6999c10a8ed8edb253836e9c/scipy-1.15.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6ac6310fdbfb7aa6612408bd2f07295bcbd3fda00d2d702178434751fe48e019", size = 38766735, upload-time = "2025-05-08T16:06:06.471Z" }, + { url = "https://files.pythonhosted.org/packages/7b/7e/f30be3d03de07f25dc0ec926d1681fed5c732d759ac8f51079708c79e680/scipy-1.15.3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:185cd3d6d05ca4b44a8f1595af87f9c372bb6acf9c808e99aa3e9aa03bd98cf6", size = 30173284, upload-time = "2025-05-08T16:06:11.686Z" }, + { url = "https://files.pythonhosted.org/packages/07/9c/0ddb0d0abdabe0d181c1793db51f02cd59e4901da6f9f7848e1f96759f0d/scipy-1.15.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:05dc6abcd105e1a29f95eada46d4a3f251743cfd7d3ae8ddb4088047f24ea477", size = 22446958, upload-time = "2025-05-08T16:06:15.97Z" }, + { url = 
"https://files.pythonhosted.org/packages/af/43/0bce905a965f36c58ff80d8bea33f1f9351b05fad4beaad4eae34699b7a1/scipy-1.15.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:06efcba926324df1696931a57a176c80848ccd67ce6ad020c810736bfd58eb1c", size = 25242454, upload-time = "2025-05-08T16:06:20.394Z" }, + { url = "https://files.pythonhosted.org/packages/56/30/a6f08f84ee5b7b28b4c597aca4cbe545535c39fe911845a96414700b64ba/scipy-1.15.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c05045d8b9bfd807ee1b9f38761993297b10b245f012b11b13b91ba8945f7e45", size = 35210199, upload-time = "2025-05-08T16:06:26.159Z" }, + { url = "https://files.pythonhosted.org/packages/0b/1f/03f52c282437a168ee2c7c14a1a0d0781a9a4a8962d84ac05c06b4c5b555/scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:271e3713e645149ea5ea3e97b57fdab61ce61333f97cfae392c28ba786f9bb49", size = 37309455, upload-time = "2025-05-08T16:06:32.778Z" }, + { url = "https://files.pythonhosted.org/packages/89/b1/fbb53137f42c4bf630b1ffdfc2151a62d1d1b903b249f030d2b1c0280af8/scipy-1.15.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6cfd56fc1a8e53f6e89ba3a7a7251f7396412d655bca2aa5611c8ec9a6784a1e", size = 36885140, upload-time = "2025-05-08T16:06:39.249Z" }, + { url = "https://files.pythonhosted.org/packages/2e/2e/025e39e339f5090df1ff266d021892694dbb7e63568edcfe43f892fa381d/scipy-1.15.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0ff17c0bb1cb32952c09217d8d1eed9b53d1463e5f1dd6052c7857f83127d539", size = 39710549, upload-time = "2025-05-08T16:06:45.729Z" }, + { url = "https://files.pythonhosted.org/packages/e6/eb/3bf6ea8ab7f1503dca3a10df2e4b9c3f6b3316df07f6c0ded94b281c7101/scipy-1.15.3-cp312-cp312-win_amd64.whl", hash = "sha256:52092bc0472cfd17df49ff17e70624345efece4e1a12b23783a1ac59a1b728ed", size = 40966184, upload-time = "2025-05-08T16:06:52.623Z" }, +] + [[package]] name = "setuptools" version = "75.8.0" diff --git a/libs/ledger/uv.lock 
b/libs/ledger/uv.lock index 6182d8e9..4048a55d 100644 --- a/libs/ledger/uv.lock +++ b/libs/ledger/uv.lock @@ -1612,7 +1612,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1621,9 +1621,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -2963,7 +2963,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/p_astro/uv.lock b/libs/p_astro/uv.lock index 363c3cd0..37b9bf49 100644 --- 
a/libs/p_astro/uv.lock +++ b/libs/p_astro/uv.lock @@ -1606,7 +1606,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1615,9 +1615,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -2989,7 +2989,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/priors/uv.lock b/libs/priors/uv.lock index 90545af6..4468a5d2 100644 --- a/libs/priors/uv.lock +++ b/libs/priors/uv.lock @@ -255,6 +255,8 @@ wheels 
= [ { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851, upload-time = "2024-08-23T15:22:07.895Z" }, { url = "https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120, upload-time = "2024-08-23T15:22:09.942Z" }, { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -705,7 +707,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -714,9 +716,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -1468,7 +1470,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/utils/uv.lock b/libs/utils/uv.lock index 1cfd3e70..d026e015 100644 --- a/libs/utils/uv.lock +++ b/libs/utils/uv.lock @@ -456,7 +456,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -465,9 +465,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] diff --git a/projects/data/uv.lock b/projects/data/uv.lock index 4008a1c0..75432640 100644 --- a/projects/data/uv.lock +++ b/projects/data/uv.lock @@ -1496,7 +1496,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1505,9 +1505,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] diff --git a/projects/export/uv.lock b/projects/export/uv.lock index a57e9dee..33ffc0d8 100644 --- a/projects/export/uv.lock +++ b/projects/export/uv.lock @@ -972,7 +972,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -981,9 +981,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] diff --git a/projects/infer/uv.lock b/projects/infer/uv.lock index 5502aeeb..4d09280d 100644 --- a/projects/infer/uv.lock +++ b/projects/infer/uv.lock @@ -1897,7 +1897,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1906,9 +1906,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = 
"https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -3510,7 +3510,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/online/uv.lock b/projects/online/uv.lock index 6687b56a..d723cc8d 100644 --- a/projects/online/uv.lock +++ b/projects/online/uv.lock @@ -2302,7 +2302,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2311,9 +2311,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = 
"https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] diff --git a/projects/plots/uv.lock b/projects/plots/uv.lock index d7ba6019..edfdb894 100644 --- a/projects/plots/uv.lock +++ b/projects/plots/uv.lock @@ -425,6 +425,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851, upload-time = "2024-08-23T15:22:07.895Z" }, { url = "https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120, upload-time = "2024-08-23T15:22:09.942Z" }, { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, 
upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -2113,7 +2115,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2122,9 +2124,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -3889,7 +3891,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/train/uv.lock b/projects/train/uv.lock index 27a12732..b54b462d 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ 
-2321,7 +2321,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2330,9 +2330,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] From 1bfde0e57b4ff765687662d1fe6c7adb30f76303 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 17:44:49 -0400 Subject: [PATCH 03/16] removing old s4 code --- .../architectures/networks/__init__.py | 1 - .../architectures/networks/s4.py | 270 ------------------ 2 files changed, 271 deletions(-) delete mode 100644 libs/architectures/architectures/networks/s4.py diff --git a/libs/architectures/architectures/networks/__init__.py b/libs/architectures/architectures/networks/__init__.py index 62e5e0f3..f638cee6 100644 --- 
a/libs/architectures/architectures/networks/__init__.py +++ b/libs/architectures/architectures/networks/__init__.py @@ -1,3 +1,2 @@ -from .s4 import S4Model from .wavenet import WaveNet from .xylophone import Xylophone diff --git a/libs/architectures/architectures/networks/s4.py b/libs/architectures/architectures/networks/s4.py deleted file mode 100644 index b295c6de..00000000 --- a/libs/architectures/architectures/networks/s4.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -Copied from https://github.com/chreissel/s4/blob/main/models/s4/s4d.py -and https://github.com/chreissel/s4/blob/main/src/models/nn/dropout.py - -Minimal version of S4D with extra options and features stripped out, -for pedagogical purposes. -""" - -import math -from typing import Optional - -import torch -import torch.nn as nn -from einops import rearrange, repeat - - -class DropoutNd(nn.Module): - def __init__(self, p: float = 0.5, tie=True, transposed=True): - """ - tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) - """ - super().__init__() - if p < 0 or p >= 1: - raise ValueError( - "dropout probability has to be in [0, 1), but got {}".format(p) - ) - self.p = p - self.tie = tie - self.transposed = transposed - self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p) - - def forward(self, X): - """X: (batch, dim, lengths...).""" - if self.training: - if not self.transposed: - X = rearrange(X, "b ... d -> b d ...") - mask_shape = ( - X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape - ) - # mask = self.binomial.sample(mask_shape) - mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p - X = X * mask * (1.0 / (1 - self.p)) - if not self.transposed: - X = rearrange(X, "b d ... -> b ... 
d") - return X - return X - - -class S4DKernel(nn.Module): - """Generate convolution kernel from diagonal SSM parameters.""" - - def __init__( - self, - d_model: int, - length: int, - N: int = 64, - dt_min: float = 0.001, - dt_max: float = 0.1, - lr: float = None, - ): - super().__init__() - - # generate dt - H = d_model - log_dt = torch.rand(H) * ( - math.log(dt_max) - math.log(dt_min) - ) + math.log(dt_min) - - C = torch.randn(H, N // 2, dtype=torch.cfloat) - self.C = nn.Parameter(torch.view_as_real(C)) - self.register("log_dt", log_dt, lr) - - log_A_real = torch.log(0.5 * torch.ones(H, N // 2)) - A_imag = math.pi * repeat(torch.arange(N // 2), "n -> h n", h=H) - self.register("log_A_real", log_A_real, lr) - self.register("A_imag", A_imag, lr) - - Ls = torch.arange(length) - self.register_buffer("length", Ls) - - def forward(self): - """ - returns: (..., c, L) where c is number of channels (default 1) - """ - - # Materialize parameters - dt = torch.exp(self.log_dt) # (H) - C = torch.view_as_complex(self.C) # (H N) - A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N) - - # Vandermonde multiplication - dtA = A * dt.unsqueeze(-1) # (H N) - K = dtA.unsqueeze(-1) * self.length # (H N L) - C = C * (torch.exp(dtA) - 1.0) / A - K = 2 * torch.einsum("hn, hnl -> hl", C, torch.exp(K)).real - - return K - - def register(self, name, tensor, lr=None): - """ - Register a tensor with a configurable learning rate - and 0 weight decay - """ - - if lr == 0.0: - self.register_buffer(name, tensor) - else: - self.register_parameter(name, nn.Parameter(tensor)) - - optim = {"weight_decay": 0.0} - if lr is not None: - optim["lr"] = lr - getattr(self, name)._optim = optim - - -class S4D(nn.Module): - def __init__( - self, - d_model: int, - length: int, - d_state: int = 64, - dropout: float = 0.0, - transposed: bool = True, - dt_min: float = 0.001, - dt_max: float = 0.1, - lr: Optional[float] = None, - ): - super().__init__() - self.transposed = transposed - self.D = 
nn.Parameter(torch.randn(d_model)) - self.length = length - - # SSM Kernel - self.kernel = S4DKernel( - d_model, - length=length, - N=d_state, - dt_min=dt_min, - dt_max=dt_max, - lr=lr, - ) - - # Pointwise - self.activation = nn.GELU() - # TODO: investigate torch dropout implementation - self.dropout = torch.nn.Dropout1d(dropout) - # self.dropout = DropoutNd(dropout) if dropout > 0.0 else nn.Identity() - - # position-wise output transform to mix features - self.output_linear = nn.Sequential( - nn.Conv1d(d_model, 2 * d_model, kernel_size=1), - nn.GLU(dim=-2), - ) - - def forward(self, u): - """Input and output shape (B, H, L)""" - if not self.transposed: - u = u.transpose(-1, -2) - - # Compute SSM Kernel - k = self.kernel() # (H L) - - # Convolution - k_f = torch.fft.rfft(k, n=2 * self.length) # (H L) - u_f = torch.fft.rfft(u, n=2 * self.length) # (B H L) - y = torch.fft.irfft(u_f * k_f, n=2 * self.length)[ - ..., : self.length - ] # (B H L) - - # Compute D term in state space equation - # Essentially a skip connection - y = y + u * self.D.unsqueeze(-1) - - y = self.dropout(self.activation(y)) - y = self.output_linear(y) - if not self.transposed: - y = y.transpose(-1, -2) - # Return a dummy state to satisfy this repo's interface, - # but this can be modified - return y, None - - -class S4Model(nn.Module): - def __init__( - self, - d_input: int, - length: int, - d_output: int = 10, - d_model: int = 256, - d_state: int = 64, - n_layers: int = 4, - dropout: float = 0.2, - prenorm: bool = False, - dt_min: float = 0.001, - dt_max: float = 0.1, - lr: Optional[float] = None, - ): - super().__init__() - - self.prenorm = prenorm - - # Linear encoder (d_input = 1 for grayscale and 3 for RGB) - self.encoder = nn.Linear(d_input, d_model) - - # Stack S4 layers as residual blocks - self.s4_layers = nn.ModuleList() - self.norms = nn.ModuleList() - self.dropouts = nn.ModuleList() - if lr is not None: - lr = min(0.001, lr) - for _ in range(n_layers): - self.s4_layers.append( - S4D( 
- length=length, - d_model=d_model, - d_state=d_state, - dropout=dropout, - transposed=True, - dt_min=dt_min, - dt_max=dt_max, - lr=lr, - ) - ) - self.norms.append(nn.LayerNorm(d_model)) - self.dropouts.append(nn.Dropout1d(dropout)) - - # Linear decoder - self.decoder = nn.Linear(d_model, d_output) - - def forward(self, x): - """ - Input x is shape (B, d_input, L) - """ - x = x.transpose(-1, -2) - x = self.encoder(x) # (B, L, d_input) -> (B, L, d_model) - - x = x.transpose(-1, -2) # (B, L, d_model) -> (B, d_model, L) - for layer, norm, dropout in zip( - self.s4_layers, self.norms, self.dropouts, strict=True - ): - # Each iteration of this loop will map - # (B, d_model, L) -> (B, d_model, L) - - z = x - if self.prenorm: - # Prenorm - z = norm(z.transpose(-1, -2)).transpose(-1, -2) - - # Apply S4 block: we ignore the state input and output - z, _ = layer(z) - - # Dropout on the output of the S4 block - z = dropout(z) - - # Residual connection - x = z + x - - if not self.prenorm: - # Postnorm - x = norm(x.transpose(-1, -2)).transpose(-1, -2) - - x = x.transpose(-1, -2) - - # Pooling: average pooling over the sequence length - x = x.mean(dim=1) - - # Decode the outputs - x = self.decoder(x) # (B, d_model) -> (B, d_output) - - return x From 05472c978c29af990d55d456013bd159c90e4f93 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 17:58:39 -0400 Subject: [PATCH 04/16] fixed S4Model import --- libs/architectures/architectures/supervised.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/architectures/architectures/supervised.py b/libs/architectures/architectures/supervised.py index 99187183..ae336b38 100644 --- a/libs/architectures/architectures/supervised.py +++ b/libs/architectures/architectures/supervised.py @@ -1,10 +1,11 @@ from typing import Literal, Optional from architectures import Architecture -from architectures.networks import S4Model, WaveNet, Xylophone +from architectures.networks import WaveNet, Xylophone from 
jaxtyping import Float from ml4gw.nn.resnet.resnet_1d import NormLayer, ResNet1D from ml4gw.nn.resnet.resnet_2d import ResNet2D +from ml4gw.nn.ssm.s4d import S4Model from torch import Tensor import torch From dffe12b61116ba44495c1aa379e9d799932b358f Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 18:02:00 -0400 Subject: [PATCH 05/16] ensure ml4gw >= 0.8.3 floor --- libs/architectures/pyproject.toml | 2 +- libs/architectures/uv.lock | 2 +- libs/ledger/uv.lock | 2 +- libs/p_astro/uv.lock | 2 +- libs/priors/uv.lock | 2 +- libs/utils/pyproject.toml | 2 +- libs/utils/uv.lock | 2 +- projects/data/pyproject.toml | 2 +- projects/data/uv.lock | 4 ++-- projects/export/pyproject.toml | 2 +- projects/export/uv.lock | 4 ++-- projects/infer/uv.lock | 2 +- projects/online/pyproject.toml | 2 +- projects/online/uv.lock | 6 +++--- projects/plots/uv.lock | 2 +- projects/train/pyproject.toml | 2 +- projects/train/uv.lock | 6 +++--- 17 files changed, 23 insertions(+), 23 deletions(-) diff --git a/libs/architectures/pyproject.toml b/libs/architectures/pyproject.toml index 7fc7ecab..7606740f 100644 --- a/libs/architectures/pyproject.toml +++ b/libs/architectures/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.10,<3.13" license = "MIT" dependencies = [ "einops>=0.8,<0.9", - "ml4gw>=0.7.2", + "ml4gw>=0.8.3", "h5py>=3.9.0,<4", "numpy~=1.26", ] diff --git a/libs/architectures/uv.lock b/libs/architectures/uv.lock index 82c5429a..012900c2 100644 --- a/libs/architectures/uv.lock +++ b/libs/architectures/uv.lock @@ -26,7 +26,7 @@ dev = [ requires-dist = [ { name = "einops", specifier = ">=0.8,<0.9" }, { name = "h5py", specifier = ">=3.9.0,<4" }, - { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = "~=1.26" }, ] diff --git a/libs/ledger/uv.lock b/libs/ledger/uv.lock index 4048a55d..09007126 100644 --- a/libs/ledger/uv.lock +++ b/libs/ledger/uv.lock @@ -2963,7 +2963,7 @@ dependencies = [ requires-dist = [ { 
name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/p_astro/uv.lock b/libs/p_astro/uv.lock index 37b9bf49..112c8511 100644 --- a/libs/p_astro/uv.lock +++ b/libs/p_astro/uv.lock @@ -2989,7 +2989,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/priors/uv.lock b/libs/priors/uv.lock index 4468a5d2..025a2df2 100644 --- a/libs/priors/uv.lock +++ b/libs/priors/uv.lock @@ -1470,7 +1470,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/utils/pyproject.toml b/libs/utils/pyproject.toml index 70dbe806..dbc7f2b7 100644 --- a/libs/utils/pyproject.toml +++ b/libs/utils/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "h5py~=3.6", "numpy>=1.26.4,<2", "s3fs>=2024,<2025", - "ml4gw>=0.8.0", + "ml4gw>=0.8.3", "astropy>=6.0.1", ] diff --git a/libs/utils/uv.lock b/libs/utils/uv.lock index d026e015..1945698d 100644 --- a/libs/utils/uv.lock +++ b/libs/utils/uv.lock @@ -1124,7 +1124,7 @@ dev = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git 
a/projects/data/pyproject.toml b/projects/data/pyproject.toml index 66088d56..08bb09ff 100644 --- a/projects/data/pyproject.toml +++ b/projects/data/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "utils", "ledger", "priors", - "ml4gw>=0.7.2", + "ml4gw>=0.8.3", "aframe", ] diff --git a/projects/data/uv.lock b/projects/data/uv.lock index 75432640..3239e37a 100644 --- a/projects/data/uv.lock +++ b/projects/data/uv.lock @@ -740,7 +740,7 @@ dev = [ requires-dist = [ { name = "aframe", editable = "../../" }, { name = "ledger", editable = "../../libs/ledger" }, - { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "priors", editable = "../../libs/priors" }, { name = "utils", editable = "../../libs/utils" }, ] @@ -2708,7 +2708,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/export/pyproject.toml b/projects/export/pyproject.toml index b95355d1..c4cf2441 100644 --- a/projects/export/pyproject.toml +++ b/projects/export/pyproject.toml @@ -6,7 +6,7 @@ authors = [{ name = "Ethan Jacob Marx", email = "ethan.marx@ligo.org" }] requires-python = ">=3.10,<3.13" license = "MIT" dependencies = [ - "ml4gw>=0.8.0", + "ml4gw>=0.8.3", "boto3~=1.30", "fsspec[s3]>=2024,<2025", "ml4gw-hermes[torch]>=0.2.1", diff --git a/projects/export/uv.lock b/projects/export/uv.lock index 33ffc0d8..38e35f82 100644 --- a/projects/export/uv.lock +++ b/projects/export/uv.lock @@ -528,7 +528,7 @@ requires-dist = [ { name = "boto3", specifier = "~=1.30" }, { name = "fsspec", extras = ["s3"], specifier = ">=2024,<2025" }, { name = "jsonargparse", specifier = ">=4.27.1,<5" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = 
"ml4gw-hermes", extras = ["torch"], specifier = ">=0.2.1" }, { name = "nvidia-cudnn-cu11", specifier = "==8.9.6.50" }, { name = "tensorrt", specifier = "==8.5.2.2" }, @@ -1980,7 +1980,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/infer/uv.lock b/projects/infer/uv.lock index 4d09280d..1eb05378 100644 --- a/projects/infer/uv.lock +++ b/projects/infer/uv.lock @@ -3510,7 +3510,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/online/pyproject.toml b/projects/online/pyproject.toml index 0da8d616..f80dbe48 100644 --- a/projects/online/pyproject.toml +++ b/projects/online/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "architectures", "arrakis>=0.2.0,<0.3", "amplfi>=0.5.5", - "ml4gw>=0.7.4", + "ml4gw>=0.8.3", "omegaconf>=2.3.0,<3", "numpy<2.0.0", "scipy<1.15", diff --git a/projects/online/uv.lock b/projects/online/uv.lock index d723cc8d..ac8d0ade 100644 --- a/projects/online/uv.lock +++ b/projects/online/uv.lock @@ -214,7 +214,7 @@ dependencies = [ requires-dist = [ { name = "einops", specifier = ">=0.8,<0.9" }, { name = "h5py", specifier = ">=3.9.0,<4" }, - { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = "~=1.26" }, ] @@ -2814,7 +2814,7 @@ requires-dist = [ { name = "ligo-gracedb", extras = ["kafka"], specifier = ">=2.15.4" }, { name = "ligo-skymap", specifier = ">=2.4.0,<3" }, { name = "matplotlib", specifier = "==3.9.4" }, - { name = 
"ml4gw", specifier = ">=0.7.4" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = "<2.0.0" }, { name = "omegaconf", specifier = ">=2.3.0,<3" }, { name = "p-astro", editable = "../../libs/p_astro" }, @@ -4458,7 +4458,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/plots/uv.lock b/projects/plots/uv.lock index edfdb894..df664e5e 100644 --- a/projects/plots/uv.lock +++ b/projects/plots/uv.lock @@ -3891,7 +3891,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/train/pyproject.toml b/projects/train/pyproject.toml index b9cf9e73..2791acd1 100644 --- a/projects/train/pyproject.toml +++ b/projects/train/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "fsspec[s3]>=2024,<2025", "urllib3>=1.25.4,<1.27", "utils", - "ml4gw>=0.8.0", + "ml4gw>=0.8.3", "aframe", "ledger", "priors", diff --git a/projects/train/uv.lock b/projects/train/uv.lock index b54b462d..d5890a5f 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ -244,7 +244,7 @@ dependencies = [ requires-dist = [ { name = "einops", specifier = ">=0.8,<0.9" }, { name = "h5py", specifier = ">=3.9.0,<4" }, - { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = "~=1.26" }, ] @@ -4373,7 +4373,7 @@ requires-dist = [ { name = "ledger", editable = "../../libs/ledger" }, { name = "lightning", specifier = "==2.2.1" }, { name = "lightray", specifier = ">=0.2.3" }, - { name 
= "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "priors", editable = "../../libs/priors" }, { name = "ray", extras = ["default", "tune"], specifier = ">=2.8.0,<3" }, { name = "s3fs", specifier = ">=2024,<2025" }, @@ -4513,7 +4513,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] From cfd1c6a1ffc41c596a3446c644e7eab4c87f5d26 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 18:04:34 -0400 Subject: [PATCH 06/16] pre-commit fix --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 5e2e711f..c066fdc8 100644 --- a/uv.lock +++ b/uv.lock @@ -3688,7 +3688,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] From 0ef235a3c4c0dc252a2309edb08b4f443d50133b Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 19:51:58 -0400 Subject: [PATCH 07/16] modified s4 lightning module so that it preserves the separate lr behavior for the s4kernel parameters while tracking the modifications in ml4gw.nn.ssm.s4d --- projects/train/train/model/supervised.py | 81 ++++++++++++++---------- 1 file changed, 47 insertions(+), 34 deletions(-) diff --git a/projects/train/train/model/supervised.py b/projects/train/train/model/supervised.py index 8bad6068..4da68160 100644 --- a/projects/train/train/model/supervised.py +++ b/projects/train/train/model/supervised.py @@ -185,54 +185,67 @@ def validation_step(self, batch, _) -> None: class SupervisedAframeS4(SupervisedAframe): - def 
__init__(self, arch: SupervisedArchitecture, *args, **kwargs) -> None: + # S4D state-space kernel parameters: trained with a small learning rate + # and no weight decay. These names match the parameters registered by + # ml4gw's S4DKernel; edit this tuple to change which params receive the + # special learning rate. + SSM_PARAM_NAMES = ("log_dt", "log_A_real", "A_imag") + + def __init__( + self, + arch: SupervisedArchitecture, + *args, + ssm_lr: float = 1e-3, + **kwargs, + ) -> None: super().__init__(arch, *args, **kwargs) + self.save_hyperparameters("ssm_lr") def forward(self, X): return self.model(X) def configure_optimizers(self): """ - S4 requires a specific optimizer setup. - - The S4 layer (A, B, C, dt) parameters typically - require a smaller learning rate (typically 0.001), - with no weight decay. + Configure the optimizer and learning-rate scheduler. - The rest of the model can be trained with a higher learning rate - (e.g. 0.004, 0.01) and weight decay (if desired). + Parameters whose names appear in SSM_PARAM_NAMES are placed in their + own optimizer group with learning rate ssm_lr and zero weight decay. + All other parameters use learning_rate (scaled by the distributed + world size) and weight_decay. A cosine-annealing schedule decays both + groups from their base learning rates over the course of training. 
""" - if not torch.distributed.is_initialized(): - world_size = 1 - else: - world_size = torch.distributed.get_world_size() - - # All parameters in the model - all_parameters = list(self.model.parameters()) - - # General parameters don't contain the special _optim key - params = [p for p in all_parameters if not hasattr(p, "_optim")] - - # Create an optimizer with the general parameters + world_size = ( + torch.distributed.get_world_size() + if torch.distributed.is_initialized() + else 1 + ) lr = self.hparams.learning_rate * world_size self._logger.info(f"Scaled lr by {world_size} to {lr}") - optimizer = torch.optim.AdamW( - params, lr=lr, weight_decay=self.hparams.weight_decay - ) - # Add parameters with special hyperparameters - hps = [p._optim for p in all_parameters if hasattr(p, "_optim")] - hps = [ - dict(s) - for s in sorted(dict.fromkeys(frozenset(hp.items()) for hp in hps)) - ] # Unique dicts - for hp in hps: - params = [ - p for p in all_parameters if getattr(p, "_optim", None) == hp + ssm_params, other_params = [], [] + for name, p in self.model.named_parameters(): + leaf = name.rsplit(".", 1)[-1] + if leaf in self.SSM_PARAM_NAMES: + ssm_params.append(p) + else: + other_params.append(p) + + optimizer = torch.optim.AdamW( + [ + { + "params": other_params, + "lr": lr, + "weight_decay": self.hparams.weight_decay, + }, + { + "params": ssm_params, + "lr": self.hparams.ssm_lr, + "weight_decay": 0.0, + }, ] - optimizer.add_param_group({"params": params, **hp}) + ) - # Create a lr scheduler + # Decay each group from its own base lr, preserving the lr ratio. scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, self.trainer.estimated_stepping_batches ) From b659bd83cea07bdacf0e185a42d737fb92b37117 Mon Sep 17 00:00:00 2001 From: Benedict Armstrong Date: Sun, 31 May 2026 14:25:56 +0200 Subject: [PATCH 08/16] Update Datasets to return Parameters alongside injected waveforms/background. 
Updated datasets, waveform generators and models to accept/return parameters as part of the batch --- projects/train/train/callbacks.py | 2 ++ projects/train/train/data/base.py | 1 + 2 files changed, 3 insertions(+) diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index ac115043..cffab890 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -53,6 +53,7 @@ def on_train_end(self, trainer, pl_module): [X] = next(iter(trainer.train_dataloader)) X = X.to(device) waveforms, params = trainer.datamodule.waveform_sampler.sample(X) + waveforms = trainer.datamodule.slice_waveforms(waveforms) X, y, _ = trainer.datamodule.inject(X, waveforms, params) if isinstance(X, tuple): X = tuple(i.cpu() for i in X) @@ -95,6 +96,7 @@ def on_train_start(self, trainer, pl_module): waveforms, params = trainer.datamodule.waveform_sampler.sample( X ) + waveforms = trainer.datamodule.slice_waveforms(waveforms) X, y, _ = trainer.datamodule.inject(X, waveforms, params) # If X is not a tuple, make it one for consistency # of format for saving to file below diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index 0ce03f68..699ec421 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -585,6 +585,7 @@ def on_after_batch_transfer(self, batch, _): else: [X] = batch waveforms, params = self.waveform_sampler.sample(X) + waveforms = self.slice_waveforms(waveforms) batch = self.inject(X=X, waveforms=waveforms, params=params) elif self.trainer.validating or self.trainer.sanity_checking: # If we're in validation mode but we're not validating From b381f325201b8326f7896069b44c14183857d790 Mon Sep 17 00:00:00 2001 From: Benedict Armstrong Date: Sun, 31 May 2026 18:11:26 +0200 Subject: [PATCH 09/16] Refactor model structure to introduce classification and regression classes. Updated Autoencoder and Supervised models to inherit from AframeClassification. 
Added new SupervisedMultiTaskAframe and SupervisedRegressionAframe classes for multi-task and regression tasks. --- projects/train/train/model/__init__.py | 3 + projects/train/train/model/autoencoder.py | 4 +- projects/train/train/model/base.py | 55 ++----------- projects/train/train/model/classification.py | 51 ++++++++++++ projects/train/train/model/multitask.py | 70 ++++++++++++++++ projects/train/train/model/regression.py | 85 ++++++++++++++++++++ projects/train/train/model/supervised.py | 4 +- 7 files changed, 220 insertions(+), 52 deletions(-) create mode 100644 projects/train/train/model/classification.py create mode 100644 projects/train/train/model/multitask.py create mode 100644 projects/train/train/model/regression.py diff --git a/projects/train/train/model/__init__.py b/projects/train/train/model/__init__.py index 92e9cdd0..672d9750 100644 --- a/projects/train/train/model/__init__.py +++ b/projects/train/train/model/__init__.py @@ -1,5 +1,8 @@ from .autoencoder import AutoencoderAframe from .base import AframeBase +from .classification import AframeClassification +from .multitask import SupervisedMultiTaskAframe +from .regression import SupervisedRegressionAframe from .supervised import ( SupervisedAframe, SupervisedAframeS4, diff --git a/projects/train/train/model/autoencoder.py b/projects/train/train/model/autoencoder.py index e2f842ec..b9a20858 100644 --- a/projects/train/train/model/autoencoder.py +++ b/projects/train/train/model/autoencoder.py @@ -2,12 +2,12 @@ from architectures.autoencoder import AutoencoderArchitecture from ml4gw.transforms import ShiftedPearsonCorrelation -from train.model.base import AframeBase +from train.model.classification import AframeClassification Tensor = torch.Tensor -class AutoencoderAframe(AframeBase): +class AutoencoderAframe(AframeClassification): # TODO: include extra init arguments for the # various loss terms we might like to include. 
# If specific architectures have loss functions diff --git a/projects/train/train/model/base.py b/projects/train/train/model/base.py index b2cd5284..3c03fb70 100644 --- a/projects/train/train/model/base.py +++ b/projects/train/train/model/base.py @@ -6,8 +6,6 @@ import torch from architectures import Architecture -from train.metrics import TimeSlideAUROC - Tensor = torch.Tensor @@ -15,40 +13,31 @@ class AframeBase(pl.LightningModule): """ Args: arch: Architecture to train on - metric: Metric used for evaluation learning_rate: Hyperparameter controlling size of gradient steps during training pct_lr_ramp: Fraction of number of training epochs over which learning rate will ramp up to its specified value - patience: - Number of epochs to wait for an increase in - validation AUROC before terminating training. - If left as `None`, will never terminate - training early - save_top_k_models: - Maximum number of best-performing model checkpoints - to keep during training + weight_decay: + L2 regularisation strength + verbose: + Enable debug-level logging """ def __init__( self, arch: Architecture, - metric: TimeSlideAUROC, learning_rate: float, pct_lr_ramp: float, weight_decay: float = 0.0, verbose: bool = False, ) -> None: super().__init__() - # construct our model up front and record all - # our hyperparameters to our logdir; self.model = arch - self.metric = metric self.verbose = verbose self._logger = self.init_logging(verbose) - self.save_hyperparameters(ignore=["arch", "metric"]) + self.save_hyperparameters(ignore=["arch"]) def init_logging(self, verbose): log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -149,7 +138,7 @@ def training_step(self, batch: tuple[Tensor, Tensor]) -> Tensor: if isinstance(loss, dict): for name, value in loss.items(): self.log( - name, + f"train/{name}", value.mean(), on_step=True, on_epoch=True, @@ -160,7 +149,7 @@ def training_step(self, batch: tuple[Tensor, Tensor]) -> Tensor: loss = loss.mean() self.log( - 
"train_loss", + "train/loss", loss, on_step=True, on_epoch=True, @@ -169,36 +158,6 @@ def training_step(self, batch: tuple[Tensor, Tensor]) -> Tensor: ) return loss - def validation_step(self, batch, _) -> None: - shift, X_bg, X_inj, params = batch - - y_bg = self.score(X_bg) - - # compute predictions over multiple views of - # each injection and use their average as our - # prediction - num_views, batch, *shape = X_inj.shape - X_inj = X_inj.view(num_views * batch, *shape) - y_fg = self.score(X_inj) - y_fg = y_fg.view(num_views, batch) - y_fg = y_fg.mean(0) - - # include the shift associated with this data - # in our outputs to reconstruct background - # timeseries at aggregation time - self.metric.update(shift, y_bg, y_fg) - - # lightning will take care of updating then - # computing the metric at the end of the - # validation epoch - self.log( - "valid_auroc", - self.metric, - on_step=True, - on_epoch=True, - sync_dist=True, - ) - def configure_optimizers(self): if not torch.distributed.is_initialized(): world_size = 1 diff --git a/projects/train/train/model/classification.py b/projects/train/train/model/classification.py new file mode 100644 index 00000000..a15c87a1 --- /dev/null +++ b/projects/train/train/model/classification.py @@ -0,0 +1,51 @@ +from architectures import Architecture + +from train.metrics import TimeSlideAUROC +from train.model.base import AframeBase +import torch + +Tensor = torch.Tensor + + +class AframeClassification(AframeBase): + """ + Extends AframeBase with a TimeSlideAUROC validation step. + + All detection-oriented models (supervised, autoencoder, multi-task) + should inherit from this class. + + Args: + arch: Architecture to train on. + metric: TimeSlideAUROC instance used to evaluate detection performance. 
+ """ + + def __init__( + self, + arch: Architecture, + metric: TimeSlideAUROC, + *args, + **kwargs, + ) -> None: + super().__init__(arch, *args, **kwargs) + self.metric = metric + + def validation_step(self, batch, _) -> None: + shift, X_bg, X_inj, params = batch + + y_bg = self.score(X_bg) + + num_views, batch_size, *shape = X_inj.shape + X_inj = X_inj.view(num_views * batch_size, *shape) + y_fg = self.score(X_inj) + y_fg = y_fg.view(num_views, batch_size).mean(0) + + self.metric.update(shift, y_bg, y_fg) + + metric_name = self.metric.__class__.__name__ + self.log( + f"validation/{metric_name}", + self.metric, + on_step=True, + on_epoch=True, + sync_dist=True, + ) diff --git a/projects/train/train/model/multitask.py b/projects/train/train/model/multitask.py new file mode 100644 index 00000000..c2dc4d70 --- /dev/null +++ b/projects/train/train/model/multitask.py @@ -0,0 +1,70 @@ +from typing import List + +import torch +import torch.nn.functional as F +from architectures import Architecture + +from train.metrics import TimeSlideAUROC +from train.model.supervised import SupervisedAframe + +Tensor = torch.Tensor + + +class SupervisedMultiTaskAframe(SupervisedAframe): + """ + Multi-task model that jointly optimizes binary classification + and injection parameter regression. + + The architecture's forward pass must return a tuple + ``(logits, param_estimates)`` where ``logits`` has shape ``(N, 1)`` + and ``param_estimates`` has shape ``(N, len(param_names))``. + + Validation uses the standard AUROC metric on the classification head, + so the existing validation infrastructure is fully reused. + + Args: + arch: + Architecture whose forward pass returns (logits, param_estimates). + param_names: + Ordered list of parameter names to regress on. Must match the + output ordering of the architecture's regression head. + regression_weight: + Scalar multiplier applied to the regression loss before summing + with the classification loss. 
+ """ + + def __init__( + self, + arch: Architecture, + metric: TimeSlideAUROC, + param_names: List[str], + *args, + regression_weight: float = 1.0, + **kwargs, + ): + super().__init__(arch, metric, *args, **kwargs) + self.param_names = param_names + self.regression_weight = regression_weight + + def score(self, X: Tensor) -> Tensor: + logits, _ = self(X) + return logits + + def train_step(self, batch): + X, y, params = batch + logits, param_estimates = self(X) + + clf_loss = F.binary_cross_entropy_with_logits(logits, y) + + targets = torch.stack([params[k] for k in self.param_names], dim=1) + mask = ~torch.isnan(targets).any(dim=1) + reg_loss = ( + F.mse_loss(param_estimates[mask], targets[mask]) + if mask.any() + else torch.zeros(1, device=X.device) + ) + + return {"classification_loss": clf_loss, "regression_loss": reg_loss} + + def compute_loss_fn(self, classification_loss, regression_loss): + return classification_loss + self.regression_weight * regression_loss diff --git a/projects/train/train/model/regression.py b/projects/train/train/model/regression.py new file mode 100644 index 00000000..c531ba48 --- /dev/null +++ b/projects/train/train/model/regression.py @@ -0,0 +1,85 @@ +from typing import List +import warnings +import torch +import torch.nn.functional as F +from architectures import Architecture + + +from train.model.base import AframeBase + +Tensor = torch.Tensor + + +class SupervisedRegressionAframe(AframeBase): + """ + Supervised model that predicts injection parameters via regression only, + with no classification head or detection objective. + + Designed to be used with ``waveform_prob=1.0`` so that every training + sample contains an injection. Non-injected samples (NaN params) are + masked out of the loss gracefully, but they are wasteful. + + For validation set ``num_valid_views=1`` in the data config — multiple + views are only meaningful for averaging detection scores, not for + parameter recovery. 
+ + The architecture's forward pass must return ``param_estimates`` of + shape ``(N, len(param_names))``. + + Args: + arch: + Architecture whose forward pass returns param_estimates. + param_names: + Ordered list of parameter names to regress on. Must match + the output ordering of the architecture's regression head. + """ + + def __init__( + self, + arch: Architecture, + param_names: List[str], + **kwargs, + ): + super().__init__(arch=arch, **kwargs) + self.param_names = param_names + + def forward(self, X: Tensor) -> Tensor: + return self.model(X) + + def score(self, X: Tensor) -> Tensor: + return self(X) + + def train_step(self, batch): + X, y, params = batch + mask = ~torch.isnan(next(iter(params.values()))) + + if not mask.any(): + warnings.warn( + "All samples in batch have NaN parameters;" + "skipping regression step.", + stacklevel=2, + ) + return torch.zeros(1, device=X.device, requires_grad=True) + + targets = torch.stack( + [params[k][mask] for k in self.param_names], dim=1 + ) + return F.mse_loss(self(X[mask]), targets) + + def validation_step(self, batch, _): + _, _, X_inj, params = batch + num_views, N, *shape = X_inj.shape + X_inj = X_inj.view(num_views * N, *shape) + + param_estimates = self(X_inj) + + for i, name in enumerate(self.param_names): + targets = params[name].repeat(num_views) + mae = F.l1_loss(param_estimates[:, i], targets) + self.log( + f"validation/mae_{name}", + mae, + on_step=False, + on_epoch=True, + sync_dist=True, + ) diff --git a/projects/train/train/model/supervised.py b/projects/train/train/model/supervised.py index 4da68160..15a5bb7d 100644 --- a/projects/train/train/model/supervised.py +++ b/projects/train/train/model/supervised.py @@ -1,13 +1,13 @@ import torch from architectures.supervised import SupervisedArchitecture -from train.model.base import AframeBase +from train.model.classification import AframeClassification from train.metrics import TimeSlideAUROC Tensor = torch.Tensor -class SupervisedAframe(AframeBase): 
+class SupervisedAframe(AframeClassification): def __init__(self, arch: SupervisedArchitecture, *args, **kwargs) -> None: super().__init__(arch, *args, **kwargs) From 08bed37a9cd927771c3547be64a38256627e2d10 Mon Sep 17 00:00:00 2001 From: Benedict Armstrong Date: Sun, 31 May 2026 19:52:07 +0200 Subject: [PATCH 10/16] Add regression architecture and multi-task training configurations - Introduced RegressionArchitecture and MultiTaskArchitecture classes in regression.py. - Added RegressionTimeDomainResNet and MultiTaskTimeDomainResNet classes for regression tasks. - Created multitask.yaml and regression.yaml for multi-task and regression training configurations. - Enhanced BaseAframeDataset to support parameter transformations. - Implemented ChirpMass and MassRatio transforms for parameter calculations. - Updated AframeWandbLogger for improved logging capabilities. --- libs/architectures/architectures/__init__.py | 6 + .../architectures/architectures/regression.py | 120 ++++++++++++++ .../train/configs/regression/multitask.yaml | 146 ++++++++++++++++++ .../train/configs/regression/regression.yaml | 136 ++++++++++++++++ projects/train/train/callbacks.py | 55 ++++++- projects/train/train/data/base.py | 10 ++ .../train/train/data/supervised/supervised.py | 2 + projects/train/train/data/waveforms/loader.py | 29 +--- projects/train/train/model/base.py | 2 +- projects/train/train/transforms.py | 21 +++ 10 files changed, 503 insertions(+), 24 deletions(-) create mode 100644 libs/architectures/architectures/regression.py create mode 100644 projects/train/configs/regression/multitask.yaml create mode 100644 projects/train/configs/regression/regression.yaml create mode 100644 projects/train/train/transforms.py diff --git a/libs/architectures/architectures/__init__.py b/libs/architectures/architectures/__init__.py index 350160f0..c3e22aaa 100644 --- a/libs/architectures/architectures/__init__.py +++ b/libs/architectures/architectures/__init__.py @@ -1,4 +1,10 @@ from .base 
import Architecture +from .regression import ( + MultiTaskArchitecture, + MultiTaskTimeDomainResNet, + RegressionArchitecture, + RegressionTimeDomainResNet, +) from .supervised import ( SupervisedArchitecture, SupervisedFrequencyDomainResNet, diff --git a/libs/architectures/architectures/regression.py b/libs/architectures/architectures/regression.py new file mode 100644 index 00000000..30a6becb --- /dev/null +++ b/libs/architectures/architectures/regression.py @@ -0,0 +1,120 @@ +from typing import Literal, Optional + +import torch +from ml4gw.nn.resnet.resnet_1d import NormLayer, ResNet1D + +from architectures import Architecture + + +class RegressionArchitecture(Architecture): + """ + Base class for regression architectures. + Forward pass returns a tensor of shape ``(N, num_params)``. + """ + + def forward(self, X: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class MultiTaskArchitecture(Architecture): + """ + Base class for multi-task architectures. + Forward pass returns ``(logits, param_estimates)`` where + ``logits`` has shape ``(N, 1)`` and ``param_estimates`` + has shape ``(N, num_params)``. + """ + + def forward( + self, X: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + +class RegressionTimeDomainResNet(ResNet1D, RegressionArchitecture): + """ + ResNet1D backbone for injection parameter regression. + Outputs a tensor of shape ``(N, num_params)``. + + Args: + num_params: + Number of parameters to regress. Must match the length + of ``param_names`` in the model config. 
+ """ + + def __init__( + self, + num_ifos: int, + sample_rate: float, + kernel_length: float, + num_params: int, + layers: list[int], + kernel_size: int = 3, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + stride_type: Optional[list[Literal["stride", "dilation"]]] = None, + norm_layer: Optional[NormLayer] = None, + ) -> None: + super().__init__( + num_ifos, + layers=layers, + classes=num_params, + kernel_size=kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + + +class MultiTaskTimeDomainResNet(MultiTaskArchitecture): + """ + Shared ResNet1D backbone with separate classification and regression heads. + Returns ``(logits, param_estimates)`` where logits has shape ``(N, 1)`` + and param_estimates has shape ``(N, num_params)``. + + Args: + embedding_dim: + Dimensionality of the shared backbone's output embedding, + i.e. the input size to both heads. + num_params: + Number of parameters to regress. Must match the length + of ``param_names`` in the model config. 
+ """ + + def __init__( + self, + num_ifos: int, + sample_rate: float, + kernel_length: float, + num_params: int, + layers: list[int], + embedding_dim: int = 512, + kernel_size: int = 3, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + stride_type: Optional[list[Literal["stride", "dilation"]]] = None, + norm_layer: Optional[NormLayer] = None, + ) -> None: + super().__init__() + self.backbone = ResNet1D( + num_ifos, + layers=layers, + classes=embedding_dim, + kernel_size=kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + self.clf_head = torch.nn.Linear(embedding_dim, 1) + self.reg_head = torch.nn.Linear(embedding_dim, num_params) + + def forward( + self, X: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + h = self.backbone(X) + return self.clf_head(h), self.reg_head(h) diff --git a/projects/train/configs/regression/multitask.yaml b/projects/train/configs/regression/multitask.yaml new file mode 100644 index 00000000..c606a24d --- /dev/null +++ b/projects/train/configs/regression/multitask.yaml @@ -0,0 +1,146 @@ +# Multi-task training config. +# Jointly optimises binary classification (AUROC) and parameter regression. +# +# The shared ResNet backbone feeds into two heads: +# - classification head: single logit for detection +# - regression head: num_params outputs for parameter estimation +# +# Validation uses the standard AUROC metric on the classification head, +# so model selection and early stopping work the same as for classification. +# +# IMPORTANT: param_names and num_params must be consistent. 
+ +model: + class_path: train.model.SupervisedMultiTaskAframe + init_args: + arch: + class_path: architectures.MultiTaskTimeDomainResNet + init_args: + layers: [3, 4, 6, 3] + # Shared backbone output dimension feeding both heads + embedding_dim: 512 + # Must equal len(param_names) below + num_params: 1 + norm_layer: + class_path: ml4gw.nn.norm.GroupNorm1DGetter + init_args: + groups: 16 + + metric: + class_path: train.metrics.TimeSlideAUROC + init_args: + max_fpr: 1e-3 + pool_length: 8 + + # Update to match the parameters available in your waveform files. + # Extrinsic params (dec, psi, phi, snr) are always available. + param_names: + - chirp_mass + + # Weight applied to the regression loss before summing with BCE. + # Tune this to balance the two tasks. + regression_weight: 1.0 + + # optimization + learning_rate: 1e-3 + pct_lr_ramp: 0.115 + weight_decay: 1e-4 + +data: + class_path: train.data.supervised.TimeDomainSupervisedAframeDataset + init_args: + background_dir: /home/barmstrong/aframe_new/data/bns/background_data + waveforms_dir: /home/barmstrong/aframe_new/data/bns/aframe_train/ + ifos: [H1, L1] + sample_rate: 2048 + + kernel_length: 2.0 + left_pad: 1.75 + right_pad: 0 + + batch_size: 256 + batches_per_epoch: 1000 + chunk_size: 10000 + chunks_per_epoch: 10 + num_files_per_batch: 4 + + fduration: 2.0 + fftlength: 2.0 + psd_length: 20 + highpass: 20 + lowpass: 1024 + + waveform_prob: 0.5 + swap_prob: 0.014 + mute_prob: 0.055 + max_num_workers: 16 + + valid_stride: 0.5 + num_valid_views: 5 + valid_livetime: 57600 + + snr_sampler: + class_path: train.augmentations.SnrSampler + init_args: + max_min_snr: 12.0 + min_min_snr: 4.0 + max_snr: 100.0 + alpha: 1.0 + decay_steps: 10000 + # distribution_type: uniform + + param_transforms: + - class_path: train.transforms.ChirpMass + + waveform_sampler: + class_path: train.data.waveforms.WaveformLoader + init_args: + ifos: [H1, L1] + sample_rate: 2048 + val_waveform_file: 
/home/barmstrong/aframe_new/data/bns/aframe_train/val_waveforms.hdf5 + training_waveform_path: /home/barmstrong/aframe_new/data/bns/uniform_chirp_mass_waveforms/training_waveforms.hdf5 + + dec: + class_path: ml4gw.distributions.Cosine + psi: + class_path: torch.distributions.Uniform + init_args: + low: 0 + high: 3.14159 + validate_args: false + phi: + class_path: torch.distributions.Uniform + init_args: + low: -3.14159 + high: 3.14159 + validate_args: false + +trainer: + logger: + - class_path: train.callbacks.AframeWandbLogger + init_args: + name: "aframe_multitask" + save_dir: outputs + project: "aframe" + + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: "validation/TimeSlideAUROC" + mode: "max" + patience: 50 + - class_path: train.callbacks.ModelCheckpoint + init_args: + monitor: "validation/TimeSlideAUROC" + mode: "max" + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + - class_path: train.callbacks.SaveAugmentedBatch + + accelerator: auto + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 25 + enable_progress_bar: true + benchmark: true diff --git a/projects/train/configs/regression/regression.yaml b/projects/train/configs/regression/regression.yaml new file mode 100644 index 00000000..e61ba50f --- /dev/null +++ b/projects/train/configs/regression/regression.yaml @@ -0,0 +1,136 @@ +# Regression-only training config. +# Predicts injection parameters directly from whitened strain. +# +# Key differences from the classification configs: +# - waveform_prob: 1.0 (every sample is injected) +# - num_valid_views: 1 (multi-view averaging is not useful for regression) +# - No AUROC metric; EarlyStopping monitors per-parameter MAE instead. +# +# IMPORTANT: param_names and num_params must be consistent. Set them to the +# parameter names present in the "parameters" group of your waveform HDF5 +# files plus any extrinsic params added by inject() (dec, psi, phi, snr). 
+ +model: + class_path: train.model.SupervisedRegressionAframe + init_args: + arch: + class_path: architectures.RegressionTimeDomainResNet + init_args: + layers: [3, 4, 6, 3] + # Must equal len(param_names) below + num_params: 1 + norm_layer: + class_path: ml4gw.nn.norm.GroupNorm1DGetter + init_args: + groups: 16 + + # Update to match the parameters available in your waveform files. + # Extrinsic params (dec, psi, phi, snr) are always available. + param_names: + - chirp_mass + + # optimization + learning_rate: 1e-3 + pct_lr_ramp: 0.115 + weight_decay: 1e-4 + +data: + class_path: train.data.supervised.TimeDomainSupervisedAframeDataset + init_args: + background_dir: /home/barmstrong/aframe_new/data/bns/background_data + waveforms_dir: /home/barmstrong/aframe_new/data/bns/aframe_train/ + ifos: [H1, L1] + sample_rate: 2048 + + kernel_length: 2.0 + left_pad: 1.75 + right_pad: 0 + + batch_size: 256 + batches_per_epoch: 1000 + chunk_size: 10000 + chunks_per_epoch: 10 + num_files_per_batch: 4 + + fduration: 2.0 + fftlength: 2.0 + psd_length: 20 + highpass: 20 + lowpass: 1024 + + # Inject into every sample so the regression loss is never sparse. 
+ waveform_prob: 1.0 + swap_prob: 0.0 + mute_prob: 0.0 + max_num_workers: 16 + + valid_stride: 0.5 + num_valid_views: 5 + valid_livetime: 57600 + + snr_sampler: + class_path: train.augmentations.SnrSampler + init_args: + max_min_snr: 12.0 + min_min_snr: 4.0 + max_snr: 100.0 + alpha: 1.0 + decay_steps: 10000 + # distribution_type: uniform + + param_transforms: + - class_path: train.transforms.ChirpMass + + waveform_sampler: + class_path: train.data.waveforms.WaveformLoader + init_args: + ifos: [H1, L1] + sample_rate: 2048 + val_waveform_file: /home/barmstrong/aframe_new/data/bns/aframe_train/val_waveforms.hdf5 + training_waveform_path: /home/barmstrong/aframe_new/data/bns/uniform_chirp_mass_waveforms/training_waveforms.hdf5 + + dec: + class_path: ml4gw.distributions.Cosine + psi: + class_path: torch.distributions.Uniform + init_args: + low: 0 + high: 3.14159 + validate_args: false + phi: + class_path: torch.distributions.Uniform + init_args: + low: -3.14159 + high: 3.14159 + validate_args: false + +trainer: + logger: + - class_path: train.callbacks.AframeWandbLogger + init_args: + name: "aframe_multitask" + save_dir: outputs + project: "aframe" + + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + # Monitor the first param's MAE; adjust to taste. 
+ monitor: "validation/mae_chirp_mass" + mode: "min" + patience: 50 + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: "validation/mae_chirp_mass" + mode: "min" + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + - class_path: train.callbacks.SaveAugmentedBatch + + accelerator: auto + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_progress_bar: true + benchmark: true diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index cffab890..9b221e37 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -1,7 +1,8 @@ import io import os import shutil -from typing import Optional +from typing import Optional, Union, Literal +from pathlib import Path import h5py import s3fs @@ -15,6 +16,58 @@ BOTO_RETRY_EXCEPTIONS = (ClientError, ConnectTimeoutError) +class AframeWandbLogger(WandbLogger): + """Thin wrapper around WandbLogger with clean type annotations. + + WandbLogger.__init__ uses ForwardRef('Run') / ForwardRef('RunDisabled') + for the `experiment` parameter, which breaks jsonargparse's get_type_hints + call. This subclass re-declares __init__ without that parameter so the + Lightning CLI can instantiate it from a YAML config. 
+ """ + + def __init__( + self, + name: Optional[str] = None, + save_dir: Union[str, Path] = ".", + version: Optional[str] = None, + offline: bool = False, + dir: Optional[Union[str, Path]] = None, + id: Optional[str] = None, + anonymous: Optional[bool] = None, + project: Optional[str] = None, + log_model: Union[Literal["all"], bool] = False, + prefix: str = "", + checkpoint_name: Optional[str] = None, + ): + super().__init__( + name=name, + save_dir=save_dir, + version=version, + offline=offline, + dir=dir, + id=id, + anonymous=anonymous, + project=project, + log_model=log_model, + prefix=prefix, + checkpoint_name=checkpoint_name, + save_code=True, + ) + + self._offline = offline + + @property + def experiment(self): + # Accessing the parent's experiment property initializes the run + exp = super().experiment + # The run has a log_code method + if not getattr(self, "_code_logged", False): + if not getattr(self, "offline", False): + exp.log_code("train") + self._code_logged = True + return exp + + class WandbSaveConfig(pl.cli.SaveConfigCallback): """ Override of `lightning.pytorch.cli.SaveConfigCallback` for use with WandB diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index 699ec421..71ab7afc 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -202,6 +202,7 @@ def __init__( chunks_per_epoch: int = 1, chunk_size: int = 10000, verbose: bool = False, + param_transforms: Optional[list[Callable]] = None, ) -> None: super().__init__() self.init_logging(verbose) @@ -223,6 +224,7 @@ def __init__( self._on_device = False self.dec, self.psi, self.phi = dec, psi, phi + self.param_transforms = param_transforms or [] self.waveform_sampler = waveform_sampler # If we're using a `WaveformLoader`, we're loading # training waveforms from disk, so have a flag tp @@ -605,9 +607,17 @@ def on_after_batch_transfer(self, batch, _): signals=signals, params=params, ) + params = self.apply_param_transforms(params) 
batch = (shift, X_bg, X_fg, params) return batch + def apply_param_transforms( + self, params: dict[str, Tensor] + ) -> dict[str, Tensor]: + for transform in self.param_transforms: + params = transform(params) + return params + @torch.no_grad() def inject(self, X: Tensor, waveforms: Tensor, params: dict[str, Tensor]): """ diff --git a/projects/train/train/data/supervised/supervised.py b/projects/train/train/data/supervised/supervised.py index c351d6c6..05b5e97a 100644 --- a/projects/train/train/data/supervised/supervised.py +++ b/projects/train/train/data/supervised/supervised.py @@ -114,4 +114,6 @@ def inject( out[idx[still_injected]] = vals[still_injected] params_out[key] = out + params_out = self.apply_param_transforms(params_out) + return X, y, psds, params_out diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py index c0f3f40c..16263399 100644 --- a/projects/train/train/data/waveforms/loader.py +++ b/projects/train/train/data/waveforms/loader.py @@ -31,15 +31,11 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) if training_waveform_path.is_dir(): - self.training_waveform_files = list( - training_waveform_path.iterdir() - ) + self.training_waveform_files = list(training_waveform_path.iterdir()) else: self.training_waveform_files = [training_waveform_path] - waveform_set = WaveformPolarizationSet.read( - self.training_waveform_files[0] - ) + waveform_set = WaveformPolarizationSet.read(self.training_waveform_files[0]) if waveform_set.right_pad != self.right_pad: raise ValueError( "Training waveform file does not have the same " @@ -144,9 +140,7 @@ def __init__( "without using chunked storage. This can have " "severe performance impacts at data loading time. 
" "If you need faster loading, try re-generating " - "your datset with chunked storage turned on.".format( - fnames - ), + "your datset with chunked storage turned on.".format(fnames), stacklevel=2, ) @@ -186,17 +180,12 @@ def load_chunk(self, fname, start, size): channel: self.mmap_datasets[fname][channel][start:end] for channel in self.channels } - params = { - k: self.param_datasets[fname][k][start:end] - for k in self.param_keys - } + params = {k: self.param_datasets[fname][k][start:end] for k in self.param_keys} return waveforms, params def sample_batch(self): # allocate batch up front - batch = np.zeros( - (self.batch_size, self.num_channels, self.waveform_size) - ) + batch = np.zeros((self.batch_size, self.num_channels, self.waveform_size)) params_buf = { k: np.zeros( self.batch_size, @@ -208,9 +197,7 @@ def sample_batch(self): for i in range(self.chunks_per_batch): fname = np.random.choice(self.fnames, p=self.probs) - chunk_size = min( - self.chunk_size, self.batch_size - i * self.chunk_size - ) + chunk_size = min(self.chunk_size, self.batch_size - i * self.chunk_size) # select a random starting index for the chunk max_start = self.sizes[fname] - chunk_size @@ -277,9 +264,7 @@ def __iter__(self): def _next(it): [waveform_chunk], param_dict_chunk = next(it) - return waveform_chunk, { - k: v[0] for k, v in param_dict_chunk.items() - } + return waveform_chunk, {k: v[0] for k, v in param_dict_chunk.items()} waveform_chunk, param_chunk = _next(it) num_waveforms, _, _ = waveform_chunk.shape diff --git a/projects/train/train/model/base.py b/projects/train/train/model/base.py index 3c03fb70..83d916a0 100644 --- a/projects/train/train/model/base.py +++ b/projects/train/train/model/base.py @@ -37,7 +37,7 @@ def __init__( self.model = arch self.verbose = verbose self._logger = self.init_logging(verbose) - self.save_hyperparameters(ignore=["arch"]) + self.save_hyperparameters(ignore=["arch", "metric"]) def init_logging(self, verbose): log_format = "%(asctime)s - 
%(name)s - %(levelname)s - %(message)s" diff --git a/projects/train/train/transforms.py b/projects/train/train/transforms.py new file mode 100644 index 00000000..d0978139 --- /dev/null +++ b/projects/train/train/transforms.py @@ -0,0 +1,21 @@ +import torch + +Tensor = torch.Tensor + + +class ChirpMass: + """Adds ``chirp_mass`` computed from ``mass_1`` and ``mass_2``.""" + + def __call__(self, params: dict[str, Tensor]) -> dict[str, Tensor]: + m1, m2 = params["mass_1"], params["mass_2"] + params["chirp_mass"] = (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2 + return params + + +class MassRatio: + """Adds ``mass_ratio`` = min(m1,m2) / max(m1,m2) in range (0, 1].""" + + def __call__(self, params: dict[str, Tensor]) -> dict[str, Tensor]: + m1, m2 = params["mass_1"], params["mass_2"] + params["mass_ratio"] = torch.minimum(m1, m2) / torch.maximum(m1, m2) + return params From eb325b91425a6ee9670978a9d05ce417b47f034e Mon Sep 17 00:00:00 2001 From: Benedict Armstrong Date: Sun, 31 May 2026 20:58:44 +0200 Subject: [PATCH 11/16] remove unnecessary slicing and fix param_keys in loader.py --- projects/train/train/callbacks.py | 2 -- projects/train/train/data/base.py | 1 - 2 files changed, 3 deletions(-) diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index 9b221e37..a5dfbed9 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -106,7 +106,6 @@ def on_train_end(self, trainer, pl_module): [X] = next(iter(trainer.train_dataloader)) X = X.to(device) waveforms, params = trainer.datamodule.waveform_sampler.sample(X) - waveforms = trainer.datamodule.slice_waveforms(waveforms) X, y, _ = trainer.datamodule.inject(X, waveforms, params) if isinstance(X, tuple): X = tuple(i.cpu() for i in X) @@ -149,7 +148,6 @@ def on_train_start(self, trainer, pl_module): waveforms, params = trainer.datamodule.waveform_sampler.sample( X ) - waveforms = trainer.datamodule.slice_waveforms(waveforms) X, y, _ = trainer.datamodule.inject(X, 
waveforms, params) # If X is not a tuple, make it one for consistency # of format for saving to file below diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index 71ab7afc..053aa317 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -587,7 +587,6 @@ def on_after_batch_transfer(self, batch, _): else: [X] = batch waveforms, params = self.waveform_sampler.sample(X) - waveforms = self.slice_waveforms(waveforms) batch = self.inject(X=X, waveforms=waveforms, params=params) elif self.trainer.validating or self.trainer.sanity_checking: # If we're in validation mode but we're not validating From f1ef5ca650d06caf5da4f3481135687721bee6f2 Mon Sep 17 00:00:00 2001 From: Benedict Armstrong Date: Sun, 31 May 2026 22:53:57 +0200 Subject: [PATCH 12/16] fix configs --- projects/train/configs/regression/multitask.yaml | 12 ++++++------ projects/train/configs/regression/regression.yaml | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/projects/train/configs/regression/multitask.yaml b/projects/train/configs/regression/multitask.yaml index c606a24d..02796524 100644 --- a/projects/train/configs/regression/multitask.yaml +++ b/projects/train/configs/regression/multitask.yaml @@ -49,8 +49,8 @@ model: data: class_path: train.data.supervised.TimeDomainSupervisedAframeDataset init_args: - background_dir: /home/barmstrong/aframe_new/data/bns/background_data - waveforms_dir: /home/barmstrong/aframe_new/data/bns/aframe_train/ + background_dir: data/bns/background_data + waveforms_dir: data/bns/aframe_train/ ifos: [H1, L1] sample_rate: 2048 @@ -82,10 +82,10 @@ data: snr_sampler: class_path: train.augmentations.SnrSampler init_args: - max_min_snr: 12.0 + alpha: -3 + max_min_snr: 8.0 min_min_snr: 4.0 max_snr: 100.0 - alpha: 1.0 decay_steps: 10000 # distribution_type: uniform @@ -97,8 +97,8 @@ data: init_args: ifos: [H1, L1] sample_rate: 2048 - val_waveform_file: 
/home/barmstrong/aframe_new/data/bns/aframe_train/val_waveforms.hdf5 - training_waveform_path: /home/barmstrong/aframe_new/data/bns/uniform_chirp_mass_waveforms/training_waveforms.hdf5 + val_waveform_file: data/bns/aframe_train/val_waveforms.hdf5 + training_waveform_path: data/bns/uniform_chirp_mass_waveforms/training_waveforms.hdf5 dec: class_path: ml4gw.distributions.Cosine diff --git a/projects/train/configs/regression/regression.yaml b/projects/train/configs/regression/regression.yaml index e61ba50f..dd60358b 100644 --- a/projects/train/configs/regression/regression.yaml +++ b/projects/train/configs/regression/regression.yaml @@ -37,8 +37,8 @@ model: data: class_path: train.data.supervised.TimeDomainSupervisedAframeDataset init_args: - background_dir: /home/barmstrong/aframe_new/data/bns/background_data - waveforms_dir: /home/barmstrong/aframe_new/data/bns/aframe_train/ + background_dir: data/bns/background_data + waveforms_dir: data/bns/aframe_train/ ifos: [H1, L1] sample_rate: 2048 @@ -71,10 +71,10 @@ data: snr_sampler: class_path: train.augmentations.SnrSampler init_args: - max_min_snr: 12.0 + alpha: -3 + max_min_snr: 8.0 min_min_snr: 4.0 max_snr: 100.0 - alpha: 1.0 decay_steps: 10000 # distribution_type: uniform @@ -86,8 +86,8 @@ data: init_args: ifos: [H1, L1] sample_rate: 2048 - val_waveform_file: /home/barmstrong/aframe_new/data/bns/aframe_train/val_waveforms.hdf5 - training_waveform_path: /home/barmstrong/aframe_new/data/bns/uniform_chirp_mass_waveforms/training_waveforms.hdf5 + val_waveform_file: data/bns/aframe_train/val_waveforms.hdf5 + training_waveform_path: data/bns/uniform_chirp_mass_waveforms/training_waveforms.hdf5 dec: class_path: ml4gw.distributions.Cosine @@ -108,7 +108,7 @@ trainer: logger: - class_path: train.callbacks.AframeWandbLogger init_args: - name: "aframe_multitask" + name: "aframe_regression" save_dir: outputs project: "aframe" From 6159b521dacbb79599e348e599b4386bd0c0d640 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 
2026 20:45:16 -0400 Subject: [PATCH 13/16] ran pre-commit hook --- .../architectures/architectures/regression.py | 8 ++--- projects/train/train/data/waveforms/loader.py | 29 ++++++++++++++----- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/libs/architectures/architectures/regression.py b/libs/architectures/architectures/regression.py index 30a6becb..ff5a02f9 100644 --- a/libs/architectures/architectures/regression.py +++ b/libs/architectures/architectures/regression.py @@ -24,9 +24,7 @@ class MultiTaskArchitecture(Architecture): has shape ``(N, num_params)``. """ - def forward( - self, X: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, X: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError @@ -113,8 +111,6 @@ def __init__( self.clf_head = torch.nn.Linear(embedding_dim, 1) self.reg_head = torch.nn.Linear(embedding_dim, num_params) - def forward( - self, X: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, X: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: h = self.backbone(X) return self.clf_head(h), self.reg_head(h) diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py index 16263399..c0f3f40c 100644 --- a/projects/train/train/data/waveforms/loader.py +++ b/projects/train/train/data/waveforms/loader.py @@ -31,11 +31,15 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) if training_waveform_path.is_dir(): - self.training_waveform_files = list(training_waveform_path.iterdir()) + self.training_waveform_files = list( + training_waveform_path.iterdir() + ) else: self.training_waveform_files = [training_waveform_path] - waveform_set = WaveformPolarizationSet.read(self.training_waveform_files[0]) + waveform_set = WaveformPolarizationSet.read( + self.training_waveform_files[0] + ) if waveform_set.right_pad != self.right_pad: raise ValueError( "Training waveform file does not have the same " @@ -140,7 +144,9 @@ 
def __init__( "without using chunked storage. This can have " "severe performance impacts at data loading time. " "If you need faster loading, try re-generating " - "your datset with chunked storage turned on.".format(fnames), + "your datset with chunked storage turned on.".format( + fnames + ), stacklevel=2, ) @@ -180,12 +186,17 @@ def load_chunk(self, fname, start, size): channel: self.mmap_datasets[fname][channel][start:end] for channel in self.channels } - params = {k: self.param_datasets[fname][k][start:end] for k in self.param_keys} + params = { + k: self.param_datasets[fname][k][start:end] + for k in self.param_keys + } return waveforms, params def sample_batch(self): # allocate batch up front - batch = np.zeros((self.batch_size, self.num_channels, self.waveform_size)) + batch = np.zeros( + (self.batch_size, self.num_channels, self.waveform_size) + ) params_buf = { k: np.zeros( self.batch_size, @@ -197,7 +208,9 @@ def sample_batch(self): for i in range(self.chunks_per_batch): fname = np.random.choice(self.fnames, p=self.probs) - chunk_size = min(self.chunk_size, self.batch_size - i * self.chunk_size) + chunk_size = min( + self.chunk_size, self.batch_size - i * self.chunk_size + ) # select a random starting index for the chunk max_start = self.sizes[fname] - chunk_size @@ -264,7 +277,9 @@ def __iter__(self): def _next(it): [waveform_chunk], param_dict_chunk = next(it) - return waveform_chunk, {k: v[0] for k, v in param_dict_chunk.items()} + return waveform_chunk, { + k: v[0] for k, v in param_dict_chunk.items() + } waveform_chunk, param_chunk = _next(it) num_waveforms, _, _ = waveform_chunk.shape From a1fe2123762f35c0f7823c666fa900b34698fbcf Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Fri, 26 Jun 2026 14:29:34 -0400 Subject: [PATCH 14/16] Add Gaussian NLL regression model and config GaussianNLLRegressionAframe predicts a mean and variance per parameter and trains with a beta-weighted Gaussian NLL (BetaNLLLoss); detection score is the negative mean 
variance. Includes an example S4D config. --- .../configs/regression/s4d_gaussian_nll.yaml | 126 ++++++++++++++++++ projects/train/train/model/regression.py | 121 ++++++++++++++++- projects/train/train/utils/__init__.py | 0 projects/train/train/utils/beta_nll_loss.py | 30 +++++ 4 files changed, 276 insertions(+), 1 deletion(-) create mode 100644 projects/train/configs/regression/s4d_gaussian_nll.yaml create mode 100644 projects/train/train/utils/__init__.py create mode 100644 projects/train/train/utils/beta_nll_loss.py diff --git a/projects/train/configs/regression/s4d_gaussian_nll.yaml b/projects/train/configs/regression/s4d_gaussian_nll.yaml new file mode 100644 index 00000000..9e11a13c --- /dev/null +++ b/projects/train/configs/regression/s4d_gaussian_nll.yaml @@ -0,0 +1,126 @@ +# S4D chirp-mass regression with a Gaussian (beta-)NLL loss. +# +# Predicts chirp_mass AND its uncertainty from whitened strain, using +# GaussianNLLRegressionAframe (mean + variance head, BetaNLLLoss) on top of an +# S4D sequence model. The detection score is the negative mean predicted +# variance (low uncertainty -> signal-like). +# +# NOTE: the architecture's d_output must be 2 * len(param_names) (one mean + +# one variance per parameter). 
+ +model: + class_path: train.model.regression.GaussianNLLRegressionAframe + init_args: + arch: + class_path: architectures.SupervisedS4Model + init_args: + num_ifos: 2 + d_output: 2 # 2 * len(param_names): chirp_mass mean + variance + d_model: 64 + d_state: 64 + n_layers: 4 + dropout: 0.2 + + param_names: + - chirp_mass + + beta_nll: 0.3 + y_mean: [1.2] # center of chirp_mass prior + y_std: [0.39] # std of Uniform(0.87, 2.22) + + learning_rate: 1e-3 + pct_lr_ramp: 0.115 + weight_decay: 1e-2 + +data: + class_path: train.data.supervised.TimeDomainSupervisedAframeDataset + init_args: + background_dir: data/bns/background + waveforms_dir: data/bns/train + ifos: [H1, L1] + sample_rate: 2048 + + kernel_length: 4.0 + left_pad: 3.5 # merger floats 3.5-4.0s from the left edge + right_pad: 0.0 + + batch_size: 16 + batches_per_epoch: 625 + chunk_size: 2000 + chunks_per_epoch: 10 + num_files_per_batch: 1 + + fduration: 1.0 + fftlength: 2.0 + psd_length: 20.0 + highpass: 20 + + waveform_prob: 1.0 + max_num_workers: 8 + + valid_stride: 0.5 + num_valid_views: 1 # regression: multi-view averaging is not useful + valid_livetime: 57600 + + snr_sampler: + class_path: train.augmentations.SnrSampler + init_args: + alpha: -3 + max_min_snr: 50.0 + min_min_snr: 8.0 + max_snr: 50.0 + decay_steps: 50000 + + param_transforms: + - class_path: train.transforms.ChirpMass + + waveform_sampler: + class_path: train.data.waveforms.WaveformLoader + init_args: + ifos: [H1, L1] + sample_rate: 2048 + val_waveform_file: data/bns/val_waveforms.hdf5 + training_waveform_path: data/bns/train_waveforms.hdf5 + + dec: + class_path: ml4gw.distributions.Cosine + psi: + class_path: torch.distributions.Uniform + init_args: + low: 0.0 + high: 3.14159 + validate_args: false + phi: + class_path: torch.distributions.Uniform + init_args: + low: -3.14159 + high: 3.14159 + validate_args: false + +trainer: + logger: + - class_path: train.callbacks.AframeWandbLogger + init_args: + name: s4d_chirp_mass_gnll_merger_4s + 
save_dir: logs + project: BNS-PUBLICATION + + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: validation/mae_chirp_mass + mode: min + patience: 50 + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: validation/mae_chirp_mass + mode: min + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + + accelerator: auto + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_progress_bar: true diff --git a/projects/train/train/model/regression.py b/projects/train/train/model/regression.py index c531ba48..7bdf36d9 100644 --- a/projects/train/train/model/regression.py +++ b/projects/train/train/model/regression.py @@ -1,11 +1,13 @@ -from typing import List +from typing import List, Optional import warnings import torch import torch.nn.functional as F +from torch import nn from architectures import Architecture from train.model.base import AframeBase +from train.utils.beta_nll_loss import BetaNLLLoss Tensor = torch.Tensor @@ -83,3 +85,137 @@ def validation_step(self, batch, _): on_epoch=True, sync_dist=True, )
+
+
+class GaussianNLLRegressionAframe(SupervisedRegressionAframe):
+    """
+    Regression model that predicts each parameter's value AND its uncertainty,
+    trained with a (beta-weighted) Gaussian negative-log-likelihood loss.
+
+    Where ``SupervisedRegressionAframe`` predicts a single value per parameter
+    and trains with MSE, this model predicts two numbers per parameter: a mean
+    and a raw variance. The raw variance is passed through ``Softplus`` to keep
+    it positive. The architecture's forward must therefore return shape
+    ``(N, 2 * len(param_names))`` with the means in the first half and the raw
+    variances in the second half.
+
+    The detection score is the negative mean predicted variance, so a
+    confident (low-uncertainty) prediction scores high.
+
+    Args:
+        arch:
+            Architecture whose forward returns ``(N, 2 * len(param_names))``.
+        param_names:
+            Ordered parameter names to regress on.
+        beta_nll:
+            beta for the beta-NLL loss (Seitzer et al. 2022). ``0`` gives plain
+            Gaussian NLL, ``0.5`` is the recommended default, ``1`` recovers an
+            MSE-like gradient.
+        y_mean, y_std:
+            Optional per-parameter normalization applied to the targets (and
+            inverted on the predictions when reporting). Length
+            ``len(param_names)``; default to zero-mean / unit-std (no-op).
+
+    Raises:
+        ValueError: if ``y_mean`` or ``y_std`` does not have exactly one entry
+            per parameter, or if any ``y_std`` entry is not positive.
+    """
+
+    def __init__(
+        self,
+        arch: Architecture,
+        param_names: List[str],
+        beta_nll: float = 0.5,
+        y_mean: Optional[List[float]] = None,
+        y_std: Optional[List[float]] = None,
+        **kwargs,
+    ):
+        super().__init__(arch=arch, param_names=param_names, **kwargs)
+        self.n_vars = len(param_names)
+        self.var_activation = nn.Softplus()
+        self.criterion = BetaNLLLoss(beta=beta_nll)
+
+        _y_mean = (
+            torch.tensor(y_mean, dtype=torch.float32)
+            if y_mean is not None
+            else torch.zeros(self.n_vars)
+        )
+        _y_std = (
+            torch.tensor(y_std, dtype=torch.float32)
+            if y_std is not None
+            else torch.ones(self.n_vars)
+        )
+        # A wrong-length vector would broadcast silently, and a non-positive
+        # std would make _normalize() divide by zero or flip signs.
+        for name, stat in (("y_mean", _y_mean), ("y_std", _y_std)):
+            if stat.shape != (self.n_vars,):
+                raise ValueError(
+                    f"{name} must have one entry per parameter "
+                    f"({self.n_vars}), got shape {tuple(stat.shape)}"
+                )
+        if not (_y_std > 0).all():
+            raise ValueError("y_std entries must all be positive")
+        self.register_buffer("y_mean", _y_mean)
+        self.register_buffer("y_std", _y_std)
+
+    def _split(self, outputs: Tensor) -> tuple[Tensor, Tensor]:
+        """Split the network output into (mean, positive variance)."""
+        mean = outputs[:, : self.n_vars]
+        var = self.var_activation(outputs[:, self.n_vars :])
+        return mean, var
+
+    def _normalize(self, y: Tensor) -> Tensor:
+        """Standardize targets with the stored per-parameter mean/std."""
+        return (y - self.y_mean) / self.y_std
+
+    def score(self, X: Tensor) -> Tensor:
+        # detection score: lower predicted uncertainty -> higher score
+        _, var = self._split(self(X))
+        return -var.mean(dim=-1)
+
+    def train_step(self, batch):
+        """Return the beta-NLL loss on the rows whose parameters are set."""
+        X, y, params = batch
+        # the first parameter's NaN pattern decides which rows are used
+        mask = ~torch.isnan(next(iter(params.values())))
+
+        if not mask.any():
+            warnings.warn(
+                "All samples in batch have NaN parameters; "
+                "skipping regression step.",
+                stacklevel=2,
+            )
+            return torch.zeros(1, device=X.device, requires_grad=True)
+
+        targets = torch.stack(
+            [params[k][mask] for k in
self.param_names], dim=1 + ) + mean, var = self._split(self(X[mask])) + return self.criterion(mean, self._normalize(targets), var) + + def validation_step(self, batch, _): + _, _, X_inj, params = batch + num_views, N, *shape = X_inj.shape + X_inj = X_inj.view(num_views * N, *shape) + + mean, var = self._split(self(X_inj)) + # convert back to physical units for reporting + mean_phys = mean * self.y_std + self.y_mean + sigma_phys = torch.sqrt(var) * self.y_std + + for i, name in enumerate(self.param_names): + targets = params[name].repeat(num_views) + self.log( + f"validation/mae_{name}", + F.l1_loss(mean_phys[:, i], targets), + on_step=False, + on_epoch=True, + sync_dist=True, + ) + self.log( + f"validation/sigma_{name}", + sigma_phys[:, i].mean(), + on_step=False, + on_epoch=True, + sync_dist=True, + ) diff --git a/projects/train/train/utils/__init__.py b/projects/train/train/utils/__init__.py new file mode 100644 index 00000000..e69de29b
diff --git a/projects/train/train/utils/beta_nll_loss.py b/projects/train/train/utils/beta_nll_loss.py
new file mode 100644
index 00000000..606f04c4
--- /dev/null
+++ b/projects/train/train/utils/beta_nll_loss.py
@@ -0,0 +1,68 @@
+from torch import nn
+import torch
+
+
+class BetaNLLLoss(nn.Module):
+    """β-NLL loss (Seitzer et al. 2022, https://arxiv.org/abs/2203.09168).
+
+    Standard GaussianNLL has a degenerate minimum where inflating var drives
+    the mean gradient to zero. β-weighting prevents this:
+
+        L_β = sg(var)^β · L_NLL
+        ∂L_β/∂mean = sg(var)^(β−1) · (mean − y)
+
+    β=0 → standard NLL (degenerate); β=0.5 → recommended; β=1 → MSE gradient.
+
+    Args:
+        beta:
+            Exponent of the stop-gradient variance weight, in ``[0, 1]``.
+        reduction:
+            ``"mean"`` or ``"sum"`` to reduce over all elements, ``"none"``
+            to return the element-wise loss.
+        eps:
+            Floor applied to the values of ``var`` for numerical stability;
+            a variance of exactly zero would otherwise give an inf/NaN loss.
+
+    Raises:
+        ValueError: if ``beta`` is outside ``[0, 1]``, ``reduction`` is not
+            one of the three modes, or ``eps`` is not positive.
+    """
+
+    def __init__(
+        self,
+        beta: float = 0.5,
+        reduction: str = "mean",
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        if not 0.0 <= beta <= 1.0:
+            raise ValueError(f"beta must be in [0, 1], got {beta}")
+        if reduction not in ("mean", "sum", "none"):
+            raise ValueError(
+                "reduction must be 'mean', 'sum' or 'none', "
+                f"got {reduction!r}"
+            )
+        if eps <= 0.0:
+            raise ValueError(f"eps must be positive, got {eps}")
+        self.beta = beta
+        self.reduction = reduction
+        self.eps = eps
+
+    def forward(
+        self, mean: torch.Tensor, target: torch.Tensor, var: torch.Tensor
+    ) -> torch.Tensor:
+        """Return the β-NLL of ``target`` under ``N(mean, var)``."""
+        # Floor the variance values without cutting their gradient, the way
+        # torch.nn.GaussianNLLLoss does: log(0) and x / 0 give inf/NaN.
+        var = var.clone()
+        with torch.no_grad():
+            var.clamp_(min=self.eps)
+        nll = 0.5 * (torch.log(var) + (mean - target) ** 2 / var)
+        if self.beta > 0.0:
+            # detach() is the stop-gradient sg(.) of the class docstring
+            nll = nll * var.detach().pow(self.beta)
+        if self.reduction == "mean":
+            return nll.mean()
+        if self.reduction == "sum":
+            return nll.sum()
+        return nll
From af8e629218193a51c26d90f0d1d92825314dc0a1 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Fri, 26 Jun 2026 14:29:38 -0400 Subject: [PATCH 15/16] Fix SupervisedS4Model for ml4gw 0.8.3 API Drop the removed length/prenorm/lr arguments and pass d_state. --- libs/architectures/architectures/supervised.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/libs/architectures/architectures/supervised.py b/libs/architectures/architectures/supervised.py index ae336b38..36b92c44 100644 --- a/libs/architectures/architectures/supervised.py +++ b/libs/architectures/architectures/supervised.py @@ -146,29 +146,23 @@ class SupervisedS4Model(S4Model, SupervisedArchitecture): def __init__( self, num_ifos: int, - sample_rate: float, - kernel_length: float, d_output: int = 1, d_model: int = 128, + d_state: int = 64, n_layers: int = 4, dropout: float = 0.1, - prenorm: bool = True, dt_min: float = 0.001, dt_max: float = 0.1, - lr: Optional[float] = None, ) -> None: - length = int(kernel_length * sample_rate) super().__init__( - length=length, d_input=num_ifos, d_output=d_output, d_model=d_model, + d_state=d_state, n_layers=n_layers, dropout=dropout, - prenorm=prenorm, dt_min=dt_min, dt_max=dt_max, - lr=lr, ) From e217dfe5857c25de71124aec198a0ea84ae27dc7 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Fri, 26 Jun 2026 14:29:41 -0400 Subject: [PATCH 16/16] OOM-safe lazy waveform loading Read metadata and per-rank slices via h5py/_load_with_idx
instead of loading whole ledgers into memory (WaveformLoader init and WaveformSampler val), add a num_val_waveforms cap, and update the sampler test mock accordingly. --- .../train/tests/data/test_waveform_sampler.py | 2 +- projects/train/train/data/waveforms/loader.py | 10 +++---- .../train/train/data/waveforms/sampler.py | 27 ++++++++++++++----- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/projects/train/tests/data/test_waveform_sampler.py b/projects/train/tests/data/test_waveform_sampler.py index 670c3e62..e6b37f2c 100644 --- a/projects/train/tests/data/test_waveform_sampler.py +++ b/projects/train/tests/data/test_waveform_sampler.py @@ -86,7 +86,7 @@ def test_non_ndarray_params_skipped(tmp_path): setattr(real_ws, target_field, 42.0) mock_cls = MagicMock() - mock_cls.read.return_value = real_ws + mock_cls._load_with_idx.return_value = real_ws with patch.object( type(sampler), diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py index c0f3f40c..9db05c5d 100644 --- a/projects/train/train/data/waveforms/loader.py +++ b/projects/train/train/data/waveforms/loader.py @@ -9,7 +9,6 @@ import numpy as np import torch -from ledger.injections import WaveformPolarizationSet from .sampler import WaveformSampler @@ -37,10 +36,11 @@ def __init__( else: self.training_waveform_files = [training_waveform_path] - waveform_set = WaveformPolarizationSet.read( - self.training_waveform_files[0] - ) - if waveform_set.right_pad != self.right_pad: + # Read only the right_pad attribute; reading the full ledger would + # load the entire (potentially hundreds of GB) waveform file into RAM. 
+ with h5py.File(self.training_waveform_files[0], "r") as f: + file_right_pad = f.attrs["right_pad"] + if file_right_pad != self.right_pad: raise ValueError( "Training waveform file does not have the same " "right pad as validation waveform file" diff --git a/projects/train/train/data/waveforms/sampler.py b/projects/train/train/data/waveforms/sampler.py index 02e5af73..48db4003 100644 --- a/projects/train/train/data/waveforms/sampler.py +++ b/projects/train/train/data/waveforms/sampler.py @@ -1,7 +1,8 @@ from dataclasses import fields from pathlib import Path -from typing import List +from typing import List, Optional +import h5py import numpy as np import torch from utils import x_per_y @@ -31,6 +32,7 @@ def __init__( ifos: List[str], sample_rate: float, val_waveform_file: Path, + num_val_waveforms: Optional[int] = None, **kwargs, ) -> None: super().__init__(*args, **kwargs) @@ -38,9 +40,17 @@ def __init__( self.sample_rate = sample_rate self.val_waveform_file = val_waveform_file - waveform_set = self.waveform_set_cls.read(val_waveform_file) - self.num_val_waveforms = len(waveform_set) - self.right_pad = waveform_set.right_pad + # Read only metadata; reading the full ledger would load the entire + # (potentially many-GB) validation file into memory. num_val_waveforms + # optionally caps how many validation waveforms are used. + with h5py.File(val_waveform_file, "r") as f: + total = int(f.attrs["num_injections"]) + self.right_pad = float(f.attrs["right_pad"]) + self.num_val_waveforms = ( + total + if num_val_waveforms is None + else min(int(num_val_waveforms), total) + ) @property def waveform_set_cls(self): @@ -78,8 +88,11 @@ def get_val_waveforms( start, stop = self.get_slice_bounds( self.num_val_waveforms, world_size, rank ) - waveform_set = self.waveform_set_cls.read(self.val_waveform_file) - waveforms = torch.Tensor(waveform_set.waveforms[start:stop]) + # Load only the [start:stop] slice rather than the whole file. 
+ idx = np.arange(start, stop) + with h5py.File(self.val_waveform_file, "r") as h5f: + waveform_set = self.waveform_set_cls._load_with_idx(h5f, idx) + waveforms = torch.Tensor(waveform_set.waveforms) params: dict[str, torch.Tensor] = {} for f in fields(waveform_set): @@ -87,7 +100,7 @@ def get_val_waveforms( continue val = getattr(waveform_set, f.name) if isinstance(val, np.ndarray): - params[f.name] = torch.from_numpy(val[start:stop]) + params[f.name] = torch.from_numpy(val) return waveforms, params