From b7f208a73b51aef1029c7d9d31eae8c4c4cc8e37 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 13:20:37 -0400 Subject: [PATCH 1/7] updating ml4gw version to 0.8.3 --- uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uv.lock b/uv.lock index b3700fae..5e2e711f 100644 --- a/uv.lock +++ b/uv.lock @@ -1911,7 +1911,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1920,9 +1920,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] From 1bf718db17628d8896d56242cd7aed824470467a Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 17:37:09 -0400 Subject: [PATCH 2/7] updating to ml4gw==0.8.3 in all the uv envs --- libs/architectures/uv.lock | 45 +++++++++++++++++++++++++++++++++++--- libs/ledger/uv.lock | 8 +++---- libs/p_astro/uv.lock | 8 +++---- libs/priors/uv.lock | 10 +++++---- libs/utils/uv.lock | 6 ++--- projects/data/uv.lock | 6 ++--- projects/export/uv.lock | 6 ++--- projects/infer/uv.lock | 8 +++---- projects/online/uv.lock | 6 ++--- projects/plots/uv.lock | 10 +++++---- projects/train/uv.lock | 6 ++--- 11 files changed, 81 insertions(+), 38 deletions(-) diff --git a/libs/architectures/uv.lock b/libs/architectures/uv.lock index 0d862327..82c5429a 100644 --- a/libs/architectures/uv.lock +++ b/libs/architectures/uv.lock @@ -174,17 +174,18 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.2" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, { name = "numpy" }, + { name = "scipy" }, { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/18/3d/9c58325b0f3d606cebc0e52e8b8600aeeca699096184a5ca0c4a2a2b3024/ml4gw-0.7.2.tar.gz", hash = "sha256:fc9f61fbc6e2fd9ae6b8654d4e0468788b56cf8631f1e4d8b4a55dba80931a90", size = 101166, upload-time = "2025-02-12T22:06:36.278Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/2d/efb9febf79d4b54392f8b45910fed080d78ea2ba377319e34919f9cda1cc/ml4gw-0.7.2-py3-none-any.whl", hash = "sha256:a8f2f508420eba1c6dc045762b1c6452ce51b43ffa3c138e2ca4bd8436540ff5", size = 123216, upload-time = "2025-02-12T22:06:35.171Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -390,6 +391,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083, upload-time = "2024-12-01T12:54:19.735Z" }, ] +[[package]] +name = "scipy" +version = "1.15.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0f/37/6964b830433e654ec7485e45a00fc9a27cf868d622838f6b6d9c5ec0d532/scipy-1.15.3.tar.gz", hash = "sha256:eae3cf522bc7df64b42cad3925c876e1b0b6c35c1337c93e12c0f366f55b0eaf", size = 59419214, upload-time = "2025-05-08T16:13:05.955Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/2f/4966032c5f8cc7e6a60f1b2e0ad686293b9474b65246b0c642e3ef3badd0/scipy-1.15.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:a345928c86d535060c9c2b25e71e87c39ab2f22fc96e9636bd74d1dbf9de448c", size = 38702770, upload-time = "2025-05-08T16:04:20.849Z" }, + { url = "https://files.pythonhosted.org/packages/a0/6e/0c3bf90fae0e910c274db43304ebe25a6b391327f3f10b5dcc638c090795/scipy-1.15.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ad3432cb0f9ed87477a8d97f03b763fd1d57709f1bbde3c9369b1dff5503b253", size = 30094511, upload-time = "2025-05-08T16:04:27.103Z" }, + { url = "https://files.pythonhosted.org/packages/ea/b1/4deb37252311c1acff7f101f6453f0440794f51b6eacb1aad4459a134081/scipy-1.15.3-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:aef683a9ae6eb00728a542b796f52a5477b78252edede72b8327a886ab63293f", size = 22368151, upload-time = "2025-05-08T16:04:31.731Z" }, + { url = "https://files.pythonhosted.org/packages/38/7d/f457626e3cd3c29b3a49ca115a304cebb8cc6f31b04678f03b216899d3c6/scipy-1.15.3-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:1c832e1bd78dea67d5c16f786681b28dd695a8cb1fb90af2e27580d3d0967e92", size = 25121732, upload-time = "2025-05-08T16:04:36.596Z" }, + { url = "https://files.pythonhosted.org/packages/db/0a/92b1de4a7adc7a15dcf5bddc6e191f6f29ee663b30511ce20467ef9b82e4/scipy-1.15.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:263961f658ce2165bbd7b99fa5135195c3a12d9bef045345016b8b50c315cb82", size = 35547617, upload-time = "2025-05-08T16:04:43.546Z" }, + { url = "https://files.pythonhosted.org/packages/8e/6d/41991e503e51fc1134502694c5fa7a1671501a17ffa12716a4a9151af3df/scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2abc762b0811e09a0d3258abee2d98e0c703eee49464ce0069590846f31d40", size = 37662964, upload-time = "2025-05-08T16:04:49.431Z" }, + { url = "https://files.pythonhosted.org/packages/25/e1/3df8f83cb15f3500478c889be8fb18700813b95e9e087328230b98d547ff/scipy-1.15.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ed7284b21a7a0c8f1b6e5977ac05396c0d008b89e05498c8b7e8f4a1423bba0e", size = 37238749, upload-time = "2025-05-08T16:04:55.215Z" }, + { url = "https://files.pythonhosted.org/packages/93/3e/b3257cf446f2a3533ed7809757039016b74cd6f38271de91682aa844cfc5/scipy-1.15.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5380741e53df2c566f4d234b100a484b420af85deb39ea35a1cc1be84ff53a5c", size = 40022383, upload-time = "2025-05-08T16:05:01.914Z" }, + { url = "https://files.pythonhosted.org/packages/d1/84/55bc4881973d3f79b479a5a2e2df61c8c9a04fcb986a213ac9c02cfb659b/scipy-1.15.3-cp310-cp310-win_amd64.whl", hash = "sha256:9d61e97b186a57350f6d6fd72640f9e99d5a4a2b8fbf4b9ee9a841eab327dc13", size = 41259201, upload-time = "2025-05-08T16:05:08.166Z" }, + { url = "https://files.pythonhosted.org/packages/96/ab/5cc9f80f28f6a7dff646c5756e559823614a42b1939d86dd0ed550470210/scipy-1.15.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:993439ce220d25e3696d1b23b233dd010169b62f6456488567e830654ee37a6b", size = 38714255, upload-time = "2025-05-08T16:05:14.596Z" }, + { url = "https://files.pythonhosted.org/packages/4a/4a/66ba30abe5ad1a3ad15bfb0b59d22174012e8056ff448cb1644deccbfed2/scipy-1.15.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:34716e281f181a02341ddeaad584205bd2fd3c242063bd3423d61ac259ca7eba", size = 30111035, upload-time = "2025-05-08T16:05:20.152Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fa/a7e5b95afd80d24313307f03624acc65801846fa75599034f8ceb9e2cbf6/scipy-1.15.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3b0334816afb8b91dab859281b1b9786934392aa3d527cd847e41bb6f45bee65", size = 22384499, upload-time = "2025-05-08T16:05:24.494Z" }, + { url = "https://files.pythonhosted.org/packages/17/99/f3aaddccf3588bb4aea70ba35328c204cadd89517a1612ecfda5b2dd9d7a/scipy-1.15.3-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6db907c7368e3092e24919b5e31c76998b0ce1684d51a90943cb0ed1b4ffd6c1", size = 25152602, upload-time = "2025-05-08T16:05:29.313Z" }, + { url = "https://files.pythonhosted.org/packages/56/c5/1032cdb565f146109212153339f9cb8b993701e9fe56b1c97699eee12586/scipy-1.15.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:721d6b4ef5dc82ca8968c25b111e307083d7ca9091bc38163fb89243e85e3889", size = 35503415, upload-time = "2025-05-08T16:05:34.699Z" }, + { url = "https://files.pythonhosted.org/packages/bd/37/89f19c8c05505d0601ed5650156e50eb881ae3918786c8fd7262b4ee66d3/scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39cb9c62e471b1bb3750066ecc3a3f3052b37751c7c3dfd0fd7e48900ed52982", size = 37652622, upload-time = "2025-05-08T16:05:40.762Z" }, + { url = "https://files.pythonhosted.org/packages/7e/31/be59513aa9695519b18e1851bb9e487de66f2d31f835201f1b42f5d4d475/scipy-1.15.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:795c46999bae845966368a3c013e0e00947932d68e235702b5c3f6ea799aa8c9", size = 37244796, upload-time = "2025-05-08T16:05:48.119Z" }, + { url = "https://files.pythonhosted.org/packages/10/c0/4f5f3eeccc235632aab79b27a74a9130c6c35df358129f7ac8b29f562ac7/scipy-1.15.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:18aaacb735ab38b38db42cb01f6b92a2d0d4b6aabefeb07f02849e47f8fb3594", size = 40047684, upload-time = "2025-05-08T16:05:54.22Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a7/0ddaf514ce8a8714f6ed243a2b391b41dbb65251affe21ee3077ec45ea9a/scipy-1.15.3-cp311-cp311-win_amd64.whl", hash = "sha256:ae48a786a28412d744c62fd7816a4118ef97e5be0bee968ce8f0a2fba7acf3bb", size = 41246504, upload-time = "2025-05-08T16:06:00.437Z" }, + { url = "https://files.pythonhosted.org/packages/37/4b/683aa044c4162e10ed7a7ea30527f2cbd92e6999c10a8ed8edb253836e9c/scipy-1.15.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6ac6310fdbfb7aa6612408bd2f07295bcbd3fda00d2d702178434751fe48e019", size = 38766735, upload-time = "2025-05-08T16:06:06.471Z" }, + { url = "https://files.pythonhosted.org/packages/7b/7e/f30be3d03de07f25dc0ec926d1681fed5c732d759ac8f51079708c79e680/scipy-1.15.3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:185cd3d6d05ca4b44a8f1595af87f9c372bb6acf9c808e99aa3e9aa03bd98cf6", size = 30173284, upload-time = "2025-05-08T16:06:11.686Z" }, + { url = "https://files.pythonhosted.org/packages/07/9c/0ddb0d0abdabe0d181c1793db51f02cd59e4901da6f9f7848e1f96759f0d/scipy-1.15.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:05dc6abcd105e1a29f95eada46d4a3f251743cfd7d3ae8ddb4088047f24ea477", size = 22446958, upload-time = "2025-05-08T16:06:15.97Z" }, + { url = "https://files.pythonhosted.org/packages/af/43/0bce905a965f36c58ff80d8bea33f1f9351b05fad4beaad4eae34699b7a1/scipy-1.15.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:06efcba926324df1696931a57a176c80848ccd67ce6ad020c810736bfd58eb1c", size = 25242454, upload-time = "2025-05-08T16:06:20.394Z" }, + { url = "https://files.pythonhosted.org/packages/56/30/a6f08f84ee5b7b28b4c597aca4cbe545535c39fe911845a96414700b64ba/scipy-1.15.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c05045d8b9bfd807ee1b9f38761993297b10b245f012b11b13b91ba8945f7e45", size = 35210199, upload-time = "2025-05-08T16:06:26.159Z" }, + { url = "https://files.pythonhosted.org/packages/0b/1f/03f52c282437a168ee2c7c14a1a0d0781a9a4a8962d84ac05c06b4c5b555/scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:271e3713e645149ea5ea3e97b57fdab61ce61333f97cfae392c28ba786f9bb49", size = 37309455, upload-time = "2025-05-08T16:06:32.778Z" }, + { url = "https://files.pythonhosted.org/packages/89/b1/fbb53137f42c4bf630b1ffdfc2151a62d1d1b903b249f030d2b1c0280af8/scipy-1.15.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6cfd56fc1a8e53f6e89ba3a7a7251f7396412d655bca2aa5611c8ec9a6784a1e", size = 36885140, upload-time = "2025-05-08T16:06:39.249Z" }, + { url = "https://files.pythonhosted.org/packages/2e/2e/025e39e339f5090df1ff266d021892694dbb7e63568edcfe43f892fa381d/scipy-1.15.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0ff17c0bb1cb32952c09217d8d1eed9b53d1463e5f1dd6052c7857f83127d539", size = 39710549, upload-time = "2025-05-08T16:06:45.729Z" }, + { url = "https://files.pythonhosted.org/packages/e6/eb/3bf6ea8ab7f1503dca3a10df2e4b9c3f6b3316df07f6c0ded94b281c7101/scipy-1.15.3-cp312-cp312-win_amd64.whl", hash = "sha256:52092bc0472cfd17df49ff17e70624345efece4e1a12b23783a1ac59a1b728ed", size = 40966184, upload-time = "2025-05-08T16:06:52.623Z" }, +] + [[package]] name = "setuptools" version = "75.8.0" diff --git a/libs/ledger/uv.lock b/libs/ledger/uv.lock index 6182d8e9..4048a55d 100644 --- a/libs/ledger/uv.lock +++ b/libs/ledger/uv.lock @@ -1612,7 +1612,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1621,9 +1621,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -2963,7 +2963,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/p_astro/uv.lock b/libs/p_astro/uv.lock index 363c3cd0..37b9bf49 100644 --- a/libs/p_astro/uv.lock +++ b/libs/p_astro/uv.lock @@ -1606,7 +1606,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1615,9 +1615,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -2989,7 +2989,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/priors/uv.lock b/libs/priors/uv.lock index 90545af6..4468a5d2 100644 --- a/libs/priors/uv.lock +++ b/libs/priors/uv.lock @@ -255,6 +255,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851, upload-time = "2024-08-23T15:22:07.895Z" }, { url = "https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120, upload-time = "2024-08-23T15:22:09.942Z" }, { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -705,7 +707,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -714,9 +716,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -1468,7 +1470,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/utils/uv.lock b/libs/utils/uv.lock index 1cfd3e70..d026e015 100644 --- a/libs/utils/uv.lock +++ b/libs/utils/uv.lock @@ -456,7 +456,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -465,9 +465,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] diff --git a/projects/data/uv.lock b/projects/data/uv.lock index 4008a1c0..75432640 100644 --- a/projects/data/uv.lock +++ b/projects/data/uv.lock @@ -1496,7 +1496,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1505,9 +1505,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] diff --git a/projects/export/uv.lock b/projects/export/uv.lock index a57e9dee..33ffc0d8 100644 --- a/projects/export/uv.lock +++ b/projects/export/uv.lock @@ -972,7 +972,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -981,9 +981,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] diff --git a/projects/infer/uv.lock b/projects/infer/uv.lock index 5502aeeb..4d09280d 100644 --- a/projects/infer/uv.lock +++ b/projects/infer/uv.lock @@ -1897,7 +1897,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1906,9 +1906,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -3510,7 +3510,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/online/uv.lock b/projects/online/uv.lock index 6687b56a..d723cc8d 100644 --- a/projects/online/uv.lock +++ b/projects/online/uv.lock @@ -2302,7 +2302,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2311,9 +2311,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] diff --git a/projects/plots/uv.lock b/projects/plots/uv.lock index d7ba6019..edfdb894 100644 --- a/projects/plots/uv.lock +++ b/projects/plots/uv.lock @@ -425,6 +425,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851, upload-time = "2024-08-23T15:22:07.895Z" }, { url = "https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120, upload-time = "2024-08-23T15:22:09.942Z" }, { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -2113,7 +2115,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2122,9 +2124,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] @@ -3889,7 +3891,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/train/uv.lock b/projects/train/uv.lock index 27a12732..b54b462d 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ -2321,7 +2321,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.8.0" +version = "0.8.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2330,9 +2330,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/1b/78a1d86e3253e3e8626b8079152caf92fb502c58ca8942d293426ad71139/ml4gw-0.8.3.tar.gz", hash = "sha256:d34aadd5d977498c3ac8922664a33874b1e5f5a29079033f0345683f7f9d1868", size = 124777, upload-time = "2026-06-24T17:16:52.947Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/3e96be1b039d7476e5ce32b701df31dc33651b53db4230e9ac38fbde003d/ml4gw-0.8.3-py3-none-any.whl", hash = "sha256:4601d2034b19b4e485c7f71a97983ea2dae9d407a8e2c126c1164c337c15462b", size = 136438, upload-time = "2026-06-24T17:16:51.813Z" }, ] [[package]] From 1bfde0e57b4ff765687662d1fe6c7adb30f76303 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 17:44:49 -0400 Subject: [PATCH 3/7] removing old s4 code --- .../architectures/networks/__init__.py | 1 - .../architectures/networks/s4.py | 270 ------------------ 2 files changed, 271 deletions(-) delete mode 100644 libs/architectures/architectures/networks/s4.py diff --git a/libs/architectures/architectures/networks/__init__.py b/libs/architectures/architectures/networks/__init__.py index 62e5e0f3..f638cee6 100644 --- a/libs/architectures/architectures/networks/__init__.py +++ b/libs/architectures/architectures/networks/__init__.py @@ -1,3 +1,2 @@ -from .s4 import S4Model from .wavenet import WaveNet from .xylophone import Xylophone diff --git a/libs/architectures/architectures/networks/s4.py b/libs/architectures/architectures/networks/s4.py deleted file mode 100644 index b295c6de..00000000 --- a/libs/architectures/architectures/networks/s4.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -Copied from https://github.com/chreissel/s4/blob/main/models/s4/s4d.py -and https://github.com/chreissel/s4/blob/main/src/models/nn/dropout.py - -Minimal version of S4D with extra options and features stripped out, -for pedagogical purposes. -""" - -import math -from typing import Optional - -import torch -import torch.nn as nn -from einops import rearrange, repeat - - -class DropoutNd(nn.Module): - def __init__(self, p: float = 0.5, tie=True, transposed=True): - """ - tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) - """ - super().__init__() - if p < 0 or p >= 1: - raise ValueError( - "dropout probability has to be in [0, 1), but got {}".format(p) - ) - self.p = p - self.tie = tie - self.transposed = transposed - self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p) - - def forward(self, X): - """X: (batch, dim, lengths...).""" - if self.training: - if not self.transposed: - X = rearrange(X, "b ... d -> b d ...") - mask_shape = ( - X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape - ) - # mask = self.binomial.sample(mask_shape) - mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p - X = X * mask * (1.0 / (1 - self.p)) - if not self.transposed: - X = rearrange(X, "b d ... -> b ... d") - return X - return X - - -class S4DKernel(nn.Module): - """Generate convolution kernel from diagonal SSM parameters.""" - - def __init__( - self, - d_model: int, - length: int, - N: int = 64, - dt_min: float = 0.001, - dt_max: float = 0.1, - lr: float = None, - ): - super().__init__() - - # generate dt - H = d_model - log_dt = torch.rand(H) * ( - math.log(dt_max) - math.log(dt_min) - ) + math.log(dt_min) - - C = torch.randn(H, N // 2, dtype=torch.cfloat) - self.C = nn.Parameter(torch.view_as_real(C)) - self.register("log_dt", log_dt, lr) - - log_A_real = torch.log(0.5 * torch.ones(H, N // 2)) - A_imag = math.pi * repeat(torch.arange(N // 2), "n -> h n", h=H) - self.register("log_A_real", log_A_real, lr) - self.register("A_imag", A_imag, lr) - - Ls = torch.arange(length) - self.register_buffer("length", Ls) - - def forward(self): - """ - returns: (..., c, L) where c is number of channels (default 1) - """ - - # Materialize parameters - dt = torch.exp(self.log_dt) # (H) - C = torch.view_as_complex(self.C) # (H N) - A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N) - - # Vandermonde multiplication - dtA = A * dt.unsqueeze(-1) # (H N) - K = dtA.unsqueeze(-1) * self.length # (H N L) - C = C * (torch.exp(dtA) - 1.0) / A - K = 2 * torch.einsum("hn, hnl -> hl", C, torch.exp(K)).real - - return K - - def register(self, name, tensor, lr=None): - """ - Register a tensor with a configurable learning rate - and 0 weight decay - """ - - if lr == 0.0: - self.register_buffer(name, tensor) - else: - self.register_parameter(name, nn.Parameter(tensor)) - - optim = {"weight_decay": 0.0} - if lr is not None: - optim["lr"] = lr - getattr(self, name)._optim = optim - - -class S4D(nn.Module): - def __init__( - self, - d_model: int, - length: int, - d_state: int = 64, - dropout: float = 0.0, - transposed: bool = True, - dt_min: float = 0.001, - dt_max: float = 0.1, - lr: Optional[float] = None, - ): - super().__init__() - self.transposed = transposed - self.D = nn.Parameter(torch.randn(d_model)) - self.length = length - - # SSM Kernel - self.kernel = S4DKernel( - d_model, - length=length, - N=d_state, - dt_min=dt_min, - dt_max=dt_max, - lr=lr, - ) - - # Pointwise - self.activation = nn.GELU() - # TODO: investigate torch dropout implementation - self.dropout = torch.nn.Dropout1d(dropout) - # self.dropout = DropoutNd(dropout) if dropout > 0.0 else nn.Identity() - - # position-wise output transform to mix features - self.output_linear = nn.Sequential( - nn.Conv1d(d_model, 2 * d_model, kernel_size=1), - nn.GLU(dim=-2), - ) - - def forward(self, u): - """Input and output shape (B, H, L)""" - if not self.transposed: - u = u.transpose(-1, -2) - - # Compute SSM Kernel - k = self.kernel() # (H L) - - # Convolution - k_f = torch.fft.rfft(k, n=2 * self.length) # (H L) - u_f = torch.fft.rfft(u, n=2 * self.length) # (B H L) - y = torch.fft.irfft(u_f * k_f, n=2 * self.length)[ - ..., : self.length - ] # (B H L) - - # Compute D term in state space equation - # Essentially a skip connection - y = y + u * self.D.unsqueeze(-1) - - y = self.dropout(self.activation(y)) - y = self.output_linear(y) - if not self.transposed: - y = y.transpose(-1, -2) - # Return a dummy state to satisfy this repo's interface, - # but this can be modified - return y, None - - -class S4Model(nn.Module): - def __init__( - self, - d_input: int, - length: int, - d_output: int = 10, - d_model: int = 256, - d_state: int = 64, - n_layers: int = 4, - dropout: float = 0.2, - prenorm: bool = False, - dt_min: float = 0.001, - dt_max: float = 0.1, - lr: Optional[float] = None, - ): - super().__init__() - - self.prenorm = prenorm - - # Linear encoder (d_input = 1 for grayscale and 3 for RGB) - self.encoder = nn.Linear(d_input, d_model) - - # Stack S4 layers as residual blocks - self.s4_layers = nn.ModuleList() - self.norms = nn.ModuleList() - self.dropouts = nn.ModuleList() - if lr is not None: - lr = min(0.001, lr) - for _ in range(n_layers): - self.s4_layers.append( - S4D( - length=length, - d_model=d_model, - d_state=d_state, - dropout=dropout, - transposed=True, - dt_min=dt_min, - dt_max=dt_max, - lr=lr, - ) - ) - self.norms.append(nn.LayerNorm(d_model)) - self.dropouts.append(nn.Dropout1d(dropout)) - - # Linear decoder - self.decoder = nn.Linear(d_model, d_output) - - def forward(self, x): - """ - Input x is shape (B, d_input, L) - """ - x = x.transpose(-1, -2) - x = self.encoder(x) # (B, L, d_input) -> (B, L, d_model) - - x = x.transpose(-1, -2) # (B, L, d_model) -> (B, d_model, L) - for layer, norm, dropout in zip( - self.s4_layers, self.norms, self.dropouts, strict=True - ): - # Each iteration of this loop will map - # (B, d_model, L) -> (B, d_model, L) - - z = x - if self.prenorm: - # Prenorm - z = norm(z.transpose(-1, -2)).transpose(-1, -2) - - # Apply S4 block: we ignore the state input and output - z, _ = layer(z) - - # Dropout on the output of the S4 block - z = dropout(z) - - # Residual connection - x = z + x - - if not self.prenorm: - # Postnorm - x = norm(x.transpose(-1, -2)).transpose(-1, -2) - - x = x.transpose(-1, -2) - - # Pooling: average pooling over the sequence length - x = x.mean(dim=1) - - # Decode the outputs - x = self.decoder(x) # (B, d_model) -> (B, d_output) - - return x From 05472c978c29af990d55d456013bd159c90e4f93 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 17:58:39 -0400 Subject: [PATCH 4/7] fixed S4Model import --- libs/architectures/architectures/supervised.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/architectures/architectures/supervised.py b/libs/architectures/architectures/supervised.py index 99187183..ae336b38 100644 --- a/libs/architectures/architectures/supervised.py +++ b/libs/architectures/architectures/supervised.py @@ -1,10 +1,11 @@ from typing import Literal, Optional from architectures import Architecture -from architectures.networks import S4Model, WaveNet, Xylophone +from architectures.networks import WaveNet, Xylophone from jaxtyping import Float from ml4gw.nn.resnet.resnet_1d import NormLayer, ResNet1D from ml4gw.nn.resnet.resnet_2d import ResNet2D +from ml4gw.nn.ssm.s4d import S4Model from torch import Tensor import torch From dffe12b61116ba44495c1aa379e9d799932b358f Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 18:02:00 -0400 Subject: [PATCH 5/7] ensure ml4gw >= 0.8.3 floor --- libs/architectures/pyproject.toml | 2 +- libs/architectures/uv.lock | 2 +- libs/ledger/uv.lock | 2 +- libs/p_astro/uv.lock | 2 +- libs/priors/uv.lock | 2 +- libs/utils/pyproject.toml | 2 +- libs/utils/uv.lock | 2 +- projects/data/pyproject.toml | 2 +- projects/data/uv.lock | 4 ++-- projects/export/pyproject.toml | 2 +- projects/export/uv.lock | 4 ++-- projects/infer/uv.lock | 2 +- projects/online/pyproject.toml | 2 +- projects/online/uv.lock | 6 +++--- projects/plots/uv.lock | 2 +- projects/train/pyproject.toml | 2 +- projects/train/uv.lock | 6 +++--- 17 files changed, 23 insertions(+), 23 deletions(-) diff --git a/libs/architectures/pyproject.toml b/libs/architectures/pyproject.toml index 7fc7ecab..7606740f 100644 --- a/libs/architectures/pyproject.toml +++ b/libs/architectures/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.10,<3.13" license = "MIT" dependencies = [ "einops>=0.8,<0.9", - "ml4gw>=0.7.2", + "ml4gw>=0.8.3", "h5py>=3.9.0,<4", "numpy~=1.26", ] diff --git a/libs/architectures/uv.lock b/libs/architectures/uv.lock index 82c5429a..012900c2 100644 --- a/libs/architectures/uv.lock +++ b/libs/architectures/uv.lock @@ -26,7 +26,7 @@ dev = [ requires-dist = [ { name = "einops", specifier = ">=0.8,<0.9" }, { name = "h5py", specifier = ">=3.9.0,<4" }, - { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = "~=1.26" }, ] diff --git a/libs/ledger/uv.lock b/libs/ledger/uv.lock index 4048a55d..09007126 100644 --- a/libs/ledger/uv.lock +++ b/libs/ledger/uv.lock @@ -2963,7 +2963,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/p_astro/uv.lock b/libs/p_astro/uv.lock index 37b9bf49..112c8511 100644 --- a/libs/p_astro/uv.lock +++ b/libs/p_astro/uv.lock @@ -2989,7 +2989,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/priors/uv.lock b/libs/priors/uv.lock index 4468a5d2..025a2df2 100644 --- a/libs/priors/uv.lock +++ b/libs/priors/uv.lock @@ -1470,7 +1470,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/libs/utils/pyproject.toml b/libs/utils/pyproject.toml index 70dbe806..dbc7f2b7 100644 --- a/libs/utils/pyproject.toml +++ b/libs/utils/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "h5py~=3.6", "numpy>=1.26.4,<2", "s3fs>=2024,<2025", - "ml4gw>=0.8.0", + "ml4gw>=0.8.3", "astropy>=6.0.1", ] diff --git a/libs/utils/uv.lock b/libs/utils/uv.lock index d026e015..1945698d 100644 --- a/libs/utils/uv.lock +++ b/libs/utils/uv.lock @@ -1124,7 +1124,7 @@ dev = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/data/pyproject.toml b/projects/data/pyproject.toml index 66088d56..08bb09ff 100644 --- a/projects/data/pyproject.toml +++ b/projects/data/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "utils", "ledger", "priors", - "ml4gw>=0.7.2", + "ml4gw>=0.8.3", "aframe", ] diff --git a/projects/data/uv.lock b/projects/data/uv.lock index 75432640..3239e37a 100644 --- a/projects/data/uv.lock +++ b/projects/data/uv.lock @@ -740,7 +740,7 @@ dev = [ requires-dist = [ { name = "aframe", editable = "../../" }, { name = "ledger", editable = "../../libs/ledger" }, - { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "priors", editable = "../../libs/priors" }, { name = "utils", editable = "../../libs/utils" }, ] @@ -2708,7 +2708,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/export/pyproject.toml b/projects/export/pyproject.toml index b95355d1..c4cf2441 100644 --- a/projects/export/pyproject.toml +++ b/projects/export/pyproject.toml @@ -6,7 +6,7 @@ authors = [{ name = "Ethan Jacob Marx", email = "ethan.marx@ligo.org" }] requires-python = ">=3.10,<3.13" license = "MIT" dependencies = [ - "ml4gw>=0.8.0", + "ml4gw>=0.8.3", "boto3~=1.30", "fsspec[s3]>=2024,<2025", "ml4gw-hermes[torch]>=0.2.1", diff --git a/projects/export/uv.lock b/projects/export/uv.lock index 33ffc0d8..38e35f82 100644 --- a/projects/export/uv.lock +++ b/projects/export/uv.lock @@ -528,7 +528,7 @@ requires-dist = [ { name = "boto3", specifier = "~=1.30" }, { name = "fsspec", extras = ["s3"], specifier = ">=2024,<2025" }, { name = "jsonargparse", specifier = ">=4.27.1,<5" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "ml4gw-hermes", extras = ["torch"], specifier = ">=0.2.1" }, { name = "nvidia-cudnn-cu11", specifier = "==8.9.6.50" }, { name = "tensorrt", specifier = "==8.5.2.2" }, @@ -1980,7 +1980,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/infer/uv.lock b/projects/infer/uv.lock index 4d09280d..1eb05378 100644 --- a/projects/infer/uv.lock +++ b/projects/infer/uv.lock @@ -3510,7 +3510,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/online/pyproject.toml b/projects/online/pyproject.toml index 0da8d616..f80dbe48 100644 --- a/projects/online/pyproject.toml +++ b/projects/online/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "architectures", "arrakis>=0.2.0,<0.3", "amplfi>=0.5.5", - "ml4gw>=0.7.4", + "ml4gw>=0.8.3", "omegaconf>=2.3.0,<3", "numpy<2.0.0", "scipy<1.15", diff --git a/projects/online/uv.lock b/projects/online/uv.lock index d723cc8d..ac8d0ade 100644 --- a/projects/online/uv.lock +++ b/projects/online/uv.lock @@ -214,7 +214,7 @@ dependencies = [ requires-dist = [ { name = "einops", specifier = ">=0.8,<0.9" }, { name = "h5py", specifier = ">=3.9.0,<4" }, - { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = "~=1.26" }, ] @@ -2814,7 +2814,7 @@ requires-dist = [ { name = "ligo-gracedb", extras = ["kafka"], specifier = ">=2.15.4" }, { name = "ligo-skymap", specifier = ">=2.4.0,<3" }, { name = "matplotlib", specifier = "==3.9.4" }, - { name = "ml4gw", specifier = ">=0.7.4" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = "<2.0.0" }, { name = "omegaconf", specifier = ">=2.3.0,<3" }, { name = "p-astro", editable = "../../libs/p_astro" }, @@ -4458,7 +4458,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/plots/uv.lock b/projects/plots/uv.lock index edfdb894..df664e5e 100644 --- a/projects/plots/uv.lock +++ b/projects/plots/uv.lock @@ -3891,7 +3891,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/train/pyproject.toml b/projects/train/pyproject.toml index b9cf9e73..2791acd1 100644 --- a/projects/train/pyproject.toml +++ b/projects/train/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "fsspec[s3]>=2024,<2025", "urllib3>=1.25.4,<1.27", "utils", - "ml4gw>=0.8.0", + "ml4gw>=0.8.3", "aframe", "ledger", "priors", diff --git a/projects/train/uv.lock b/projects/train/uv.lock index b54b462d..d5890a5f 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ -244,7 +244,7 @@ dependencies = [ requires-dist = [ { name = "einops", specifier = ">=0.8,<0.9" }, { name = "h5py", specifier = ">=3.9.0,<4" }, - { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = "~=1.26" }, ] @@ -4373,7 +4373,7 @@ requires-dist = [ { name = "ledger", editable = "../../libs/ledger" }, { name = "lightning", specifier = "==2.2.1" }, { name = "lightray", specifier = ">=0.2.3" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "priors", editable = "../../libs/priors" }, { name = "ray", extras = ["default", "tune"], specifier = ">=2.8.0,<3" }, { name = "s3fs", specifier = ">=2024,<2025" }, @@ -4513,7 +4513,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] From cfd1c6a1ffc41c596a3446c644e7eab4c87f5d26 Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 18:04:34 -0400 Subject: [PATCH 6/7] pre-commit fix --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 5e2e711f..c066fdc8 100644 --- a/uv.lock +++ b/uv.lock @@ -3688,7 +3688,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.8.0" }, + { name = "ml4gw", specifier = ">=0.8.3" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] From 0ef235a3c4c0dc252a2309edb08b4f443d50133b Mon Sep 17 00:00:00 2001 From: kyoon-mit Date: Wed, 24 Jun 2026 19:51:58 -0400 Subject: [PATCH 7/7] modified s4 lightning module so that it preserves the separate lr behavior for the s4kernel parameters while tracking the modifications in ml4gw.nn.ssm.s4d --- projects/train/train/model/supervised.py | 81 ++++++++++++++---------- 1 file changed, 47 insertions(+), 34 deletions(-) diff --git a/projects/train/train/model/supervised.py b/projects/train/train/model/supervised.py index 8bad6068..4da68160 100644 --- a/projects/train/train/model/supervised.py +++ b/projects/train/train/model/supervised.py @@ -185,54 +185,67 @@ def validation_step(self, batch, _) -> None: class SupervisedAframeS4(SupervisedAframe): - def __init__(self, arch: SupervisedArchitecture, *args, **kwargs) -> None: + # S4D state-space kernel parameters: trained with a small learning rate + # and no weight decay. These names match the parameters registered by + # ml4gw's S4DKernel; edit this tuple to change which params receive the + # special learning rate. + SSM_PARAM_NAMES = ("log_dt", "log_A_real", "A_imag") + + def __init__( + self, + arch: SupervisedArchitecture, + *args, + ssm_lr: float = 1e-3, + **kwargs, + ) -> None: super().__init__(arch, *args, **kwargs) + self.save_hyperparameters("ssm_lr") def forward(self, X): return self.model(X) def configure_optimizers(self): """ - S4 requires a specific optimizer setup. - - The S4 layer (A, B, C, dt) parameters typically - require a smaller learning rate (typically 0.001), - with no weight decay. + Configure the optimizer and learning-rate scheduler. - The rest of the model can be trained with a higher learning rate - (e.g. 0.004, 0.01) and weight decay (if desired). + Parameters whose names appear in SSM_PARAM_NAMES are placed in their + own optimizer group with learning rate ssm_lr and zero weight decay. + All other parameters use learning_rate (scaled by the distributed + world size) and weight_decay. A cosine-annealing schedule decays both + groups from their base learning rates over the course of training. """ - if not torch.distributed.is_initialized(): - world_size = 1 - else: - world_size = torch.distributed.get_world_size() - - # All parameters in the model - all_parameters = list(self.model.parameters()) - - # General parameters don't contain the special _optim key - params = [p for p in all_parameters if not hasattr(p, "_optim")] - - # Create an optimizer with the general parameters + world_size = ( + torch.distributed.get_world_size() + if torch.distributed.is_initialized() + else 1 + ) lr = self.hparams.learning_rate * world_size self._logger.info(f"Scaled lr by {world_size} to {lr}") - optimizer = torch.optim.AdamW( - params, lr=lr, weight_decay=self.hparams.weight_decay - ) - # Add parameters with special hyperparameters - hps = [p._optim for p in all_parameters if hasattr(p, "_optim")] - hps = [ - dict(s) - for s in sorted(dict.fromkeys(frozenset(hp.items()) for hp in hps)) - ] # Unique dicts - for hp in hps: - params = [ - p for p in all_parameters if getattr(p, "_optim", None) == hp + ssm_params, other_params = [], [] + for name, p in self.model.named_parameters(): + leaf = name.rsplit(".", 1)[-1] + if leaf in self.SSM_PARAM_NAMES: + ssm_params.append(p) + else: + other_params.append(p) + + optimizer = torch.optim.AdamW( + [ + { + "params": other_params, + "lr": lr, + "weight_decay": self.hparams.weight_decay, + }, + { + "params": ssm_params, + "lr": self.hparams.ssm_lr, + "weight_decay": 0.0, + }, ] - optimizer.add_param_group({"params": params, **hp}) + ) - # Create a lr scheduler + # Decay each group from its own base lr, preserving the lr ratio. scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, self.trainer.estimated_stepping_batches )