From c8c9a86bba75cce6da82a55a09a0519443375696 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Mon, 16 Mar 2026 17:59:20 -0700 Subject: [PATCH 1/6] Reorganize Sequence object --- projects/infer/infer/data.py | 402 ++++++++++++++++++++++++++--------- projects/infer/infer/main.py | 4 +- 2 files changed, 298 insertions(+), 108 deletions(-) diff --git a/projects/infer/infer/data.py b/projects/infer/infer/data.py index 176e273ae..567f1428e 100644 --- a/projects/infer/infer/data.py +++ b/projects/infer/infer/data.py @@ -1,5 +1,7 @@ import logging import math +import re +from abc import ABC, abstractmethod from contextlib import nullcontext from typing import Optional from zlib import adler32 @@ -7,12 +9,145 @@ import h5py import numpy as np from ratelimiter import RateLimiter +from gwpy.timeseries import TimeSeriesDict from ledger.events import EventSet, RecoveredInjectionSet from ledger.injections import InterferometerResponseSet, waveform_class_factory -class Sequence: +FNAME_PATTERNS = { + "prefix": "[a-zA-Z0-9_:-]+", + "start": "[0-9]{10}", + "duration": "[1-9][0-9]*", + "suffix": "(gwf)|(hdf5)|(h5)", +} +FNAME_GROUPS = {k: f"(?P<{k}>{v})" for k, v in FNAME_PATTERNS.items()} +FNAME_PATTERN = "{prefix}-{start}-{duration}.{suffix}".format(**FNAME_GROUPS) +FNAME_RE = re.compile(FNAME_PATTERN) + + +class BaseSequence(ABC): + def __init__( + self, + inference_sampling_rate: float, + batch_size: int, + rate: float | None = None, + **kwargs, + ): + self.inference_sampling_rate = inference_sampling_rate + self.batch_size = batch_size + + # Subclasses provide source-specific setup for metadata and data. 
+ self._setup(**kwargs) + + self.stride = int(self.sample_rate / inference_sampling_rate) + self.step_size = self.stride * batch_size + + # derive some properties from that metadata, + # including come up with a semi-unique sequence + # id derived from a hash of its most descriptive parts + fingerprint = f"{self.t0}{self.duration}{self.shifts}".encode() + self.id = adler32(fingerprint) + self._initialize_sequence_state() + + # rate refers to the average number of requests + # per second, but remember that each yield + # corresponds to two inference requests. Rather + # than splitting the period in half, we'll allow + # two calls during a given period to help account + # for the time required to e.g. serialize the data + # into inference requests + self.limiter = ( + RateLimiter(max_calls=2, period=3.5 / rate) + if rate + else nullcontext() + ) + + def _initialize_sequence_state(self): + self._started = {} + self._done = {} + self._sequences = {} + size = len(self) * self.batch_size + for i in range(2): + seq_id = self.id + i + self._started[seq_id] = False + self._done[seq_id] = False + self._sequences[seq_id] = np.zeros(size) + + def _finalize_sequence(self, has_foreground: bool = True): + if not self.done: + return None + + # if both the background and foreground + # sequences have completed, return them both, + # slicing off the dummy data from the last batch + background = self._sequences[self.id][self.slice] + foreground = None + if has_foreground: + foreground = self._sequences[self.id + 1][self.slice] + return background, foreground + + def _record_response(self, y, request_id, sequence_id): + # insert the response at the appropriate + # spot in the corresponding output array + start = request_id * self.batch_size + stop = (request_id + 1) * self.batch_size + self._sequences[sequence_id][start:stop] = y[:, 0] + + # indicate that the first response for + # this sequence has returned, and possibly + # that the last one has returned as well + 
self._started[sequence_id] = True + if request_id == len(self) - 1: + self._done[sequence_id] = True + + @abstractmethod + def _setup(self, **kwargs): + """Subclasses must set sample_rate, size, t0, duration, and shifts.""" + pass + + @property + def started(self): + return all(self._started.values()) + + @property + def done(self): + return all(self._done.values()) + + @property + def remainder(self): + # the number of remaining data points not filling a full batch + return (self.size - max(self.shifts)) % self.step_size + + @property + def num_pad(self): + # the number of zeros we need to pad the last batch + # to make it a full batch + return (self.step_size - self.remainder) % self.step_size + + @property + def slice(self) -> slice: + """ + The number of inference requests we need to slice + off the end of the sequences to remove + the dummy data from the last batch + """ + + # if num_pad is 0 don't slice anything + num_slice = self.num_pad // self.stride + end = -num_slice if num_slice else None + return slice(end) + + def __len__(self): + # this include excess data at end of sequence that can't + # be used for a full batch. 
We'll end up padding it + # with zeros to make it a full batch and + # slicing off the actual useful inference requests + # corresponding to the excess + return math.ceil((self.size - max(self.shifts)) / self.step_size) + + +class Hdf5Sequence(BaseSequence): def __init__( self, background_fname: str, @@ -50,15 +185,33 @@ def __init__( """ logging.info("Initializing sequence") + super().__init__( + inference_sampling_rate=inference_sampling_rate, + batch_size=batch_size, + rate=rate, + background_fname=background_fname, + injection_set_fname=injection_set_fname, + ifos=ifos, + shifts=shifts, + ) + + if self.injection_set is None: + self._done[self.id + 1] = True + self._started[self.id + 1] = True + + def _setup( + self, + background_fname: str, + injection_set_fname: str, + ifos: list[str], + shifts: list[float], + ): self.background_fname = background_fname - self.inference_sampling_rate = inference_sampling_rate - self.batch_size = batch_size - self.rate = rate self.ifos = ifos if len(ifos) != len(shifts): raise ValueError( - "Number of ifos must match number of shifts" + "Number of ifos must match number of shifts; " f"got {len(ifos)} ifos and {len(shifts)} shifts" ) @@ -96,86 +249,9 @@ def __init__( injection_set = None self.injection_set = injection_set - - # derive some properties from that metadata, - # including come up with a semi-unique sequence - # id derived from a hash of its most descriptive parts - fingerprint = f"{self.t0}{self.duration}{shifts}".encode() - self.id = adler32(fingerprint) - self.shifts = [int(i * self.sample_rate) for i in shifts] - self.stride = int(self.sample_rate / inference_sampling_rate) - self.step_size = self.stride * batch_size - # initialize some containers for handling during - # the inference response callback - self._started = {} - self._done = {} - self._sequences = {} - size = len(self) * self.batch_size - for i in range(2): - seq_id = self.id + i - self._started[seq_id] = False - self._done[seq_id] = False - 
self._sequences[seq_id] = np.zeros(size) - - # if there are no injections, we can mark - # the injection sequence as started and done - if self.injection_set is None: - self._done[self.id + 1] = True - self._started[self.id + 1] = True - - @property - def started(self): - return all(self._started.values()) - - @property - def done(self): - return all(self._done.values()) - - @property - def remainder(self): - # the number of remaining data points not filling a full batch - return (self.size - max(self.shifts)) % self.step_size - - @property - def num_pad(self): - # the number of zeros we need to pad the last batch - # to make it a full batch - return (self.step_size - self.remainder) % self.step_size - - @property - def slice(self) -> slice: - """ - The number of inference requests we need to slice - off the end of the sequences to remove - the dummy data from the last batch - """ - - # if num_pad is 0 don't slice anything - num_slice = self.num_pad // self.stride - end = -num_slice if num_slice else None - return slice(end) - - def __len__(self): - # this include excess data at end of sequence that can't - # be used for a full batch. We'll end up padding it - # with zeros to make it a full batch and - # slicing off the actual useful inference requests - # corresponding to the excess - return math.ceil((self.size - max(self.shifts)) / self.step_size) + self.shifts = np.array([int(i * self.sample_rate) for i in shifts]) def __iter__(self): - if self.rate is not None: - # rate refers to the average number of requests - # per second, but remember that each yield - # corresponds to two inference requests. Rather - # than splitting the period in half, we'll allow - # two calls during a given period to help account - # for the time required to e.g. 
serialize the data - # into inference requests - limiter = RateLimiter(max_calls=2, period=3.5 / self.rate) - else: - limiter = nullcontext() - with h5py.File(self.background_fname, "r") as f: for i in range(len(self)): # if this is the last batch, we may need to pad it @@ -216,32 +292,146 @@ def __iter__(self): # return the two sets of updates, possibly # rate limited if we specified a max rate - with limiter: + with self.limiter: yield x, x_inj def __call__(self, y, request_id, sequence_id): - # insert the response at the appropriate - # spot in the corresponding output array - start = request_id * self.batch_size - stop = (request_id + 1) * self.batch_size - self._sequences[sequence_id][start:stop] = y[:, 0] - - # indicate that the first response for - # this sequence has returned, and possibly - # that the last one has returned as well - self._started[sequence_id] = True - if request_id == len(self) - 1: - self._done[sequence_id] = True - - # if both the background and foreground - # sequences have completed, return them both, - # slicing off the dummy data from the last batch - if self.done: - background = self._sequences[self.id][self.slice] - foreground = None - if self.injection_set is not None: - foreground = self._sequences[self.id + 1][self.slice] - return background, foreground + self._record_response(y, request_id, sequence_id) + return self._finalize_sequence( + has_foreground=self.injection_set is not None + ) def recover(self, foreground: EventSet) -> RecoveredInjectionSet: return RecoveredInjectionSet.recover(foreground, self.injection_set) + + +class RnPSequence(BaseSequence): + def __init__( + self, + injection_files: list[str], + channel: str, + ifos: list[str], + sample_rate: float, + inference_sampling_rate: float, + batch_size: int, + rate: Optional[float] = None, + ): + """ + Object used for iterating over a segment of data, + performing timeshifts, optionally injecting waveforms, and + aggregating the returned inference outputs. 
+ + If the injection set is empty for this given + segment and shifts, inference on injections will be skipped, + and `None` will be returned for the foreground events. + + Args: + injection_files: + List of R&P injection files to be analyzed. + channel: + Name of the channel within the frame file + ifos: + Interferometer names + sample_rate: + Sample rate that data will be resampled to for inference + inference_sampling_rate: + Rate at which inference is performed + batch_size: + Number of inference requests to send to the model at once + rate: + Rate at which to send requests in Hz + """ + logging.info("Initializing sequence") + + self.sample_rate = sample_rate + super().__init__( + inference_sampling_rate=inference_sampling_rate, + batch_size=batch_size, + rate=rate, + injection_files=injection_files, + channel=channel, + ifos=ifos, + ) + + def _setup( + self, + injection_files: list[str], + channel: str, + ifos: list[str], + ): + self.ifos = ifos + + # Don't shift timeseries for R&P injections + self.shifts = np.zeros(len(ifos), dtype=int) + self.channels = [f"{ifo}:{channel}" for ifo in ifos] + + if not injection_files: + raise ValueError("Must provide at least one injection file") + + injection_files = sorted(injection_files) + + matches = [FNAME_RE.search(fname) for fname in injection_files] + if not all(matches): + raise ValueError( + "All injection files must match expected name pattern" + ) + + starts = [int(match.group("start")) for match in matches] + durations = [int(match.group("duration")) for match in matches] + + self.t0 = min(starts) + self.duration = sum(durations) + self.size = int(self.duration * self.sample_rate) + self.timeseries = np.zeros((len(ifos), self.size)) + + # Load and resample data from each file + for file, start, duration in zip( + injection_files, starts, durations, strict=True + ): + start_idx = int(self.sample_rate * (start - self.t0)) + end_idx = start_idx + int(self.sample_rate * duration) + injected = 
TimeSeriesDict.read(file, channels=self.channels) + injected = injected.resample(self.sample_rate) + self.timeseries[:, start_idx:end_idx] = np.stack( + [injected[ch].value for ch in self.channels] + ) + + def __iter__(self): + for i in range(len(self)): + # if this is the last batch, we may need to pad it + # to make it a full batch + last = i == len(self) - 1 + # grab the current batch of updates from the file + # and stack it into a 2D array + start = i * self.step_size + + # for all but last batch just + # increase by step size + end = start + self.step_size + + # if this is the last batch + # and we need to pad it + # just step by the remainder + if last and self.remainder: + end = start + self.remainder + + x_inj = self.timeseries[:, start:end] + + if last: + x_inj = np.pad(x_inj, ((0, 0), (0, self.num_pad)), "constant") + + # TODO: Do we need this type conversion? + x_inj = x_inj.astype(np.float32) + + # yield the same data twice, as we normally + # expect to pass background and foreground + with self.limiter: + yield x_inj, x_inj + + def __call__(self, y, request_id, sequence_id): + self._record_response(y, request_id, sequence_id) + return self._finalize_sequence(has_foreground=True) + + # TODO: format the foreground into whatever R&P expects + def recover(self, foreground: EventSet) -> EventSet: + return foreground diff --git a/projects/infer/infer/main.py b/projects/infer/infer/main.py index 67d9260f5..ce2250b6c 100644 --- a/projects/infer/infer/main.py +++ b/projects/infer/infer/main.py @@ -5,13 +5,13 @@ from hermes.aeriel.client import InferenceClient from tqdm import tqdm -from infer.data import Sequence +from infer.data import BaseSequence from infer.postprocess import Postprocessor def infer( client: InferenceClient, - sequence: Sequence, + sequence: BaseSequence, postprocessor: Postprocessor, return_timeseries: bool = False, ): From 032f7c11d805a4b9dda3b7205f91c36a09ffa20f Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 17 Mar 2026 
06:47:51 -0700 Subject: [PATCH 2/6] Add tasks for RnP analysis --- aframe/tasks/infer/__init__.py | 2 +- aframe/tasks/infer/base.py | 227 ++++++++++++++++++++++----------- aframe/tasks/infer/infer.py | 131 +++++++++++++++---- 3 files changed, 262 insertions(+), 98 deletions(-) diff --git a/aframe/tasks/infer/__init__.py b/aframe/tasks/infer/__init__.py index 3a028331d..806b8511f 100644 --- a/aframe/tasks/infer/__init__.py +++ b/aframe/tasks/infer/__init__.py @@ -1 +1 @@ -from .infer import Infer +from .infer import Infer, InferRnP diff --git a/aframe/tasks/infer/base.py b/aframe/tasks/infer/base.py index 46766a769..f3a4235d0 100644 --- a/aframe/tasks/infer/base.py +++ b/aframe/tasks/infer/base.py @@ -1,6 +1,7 @@ import json import os import warnings +from abc import abstractmethod from pathlib import Path import h5py @@ -115,77 +116,34 @@ def timeseries_output(self): def metadata_output(self): return self.tmp_dir / "metadata.json" - @property - def background_fnames(self): - return self.workflow_input()["data"].collection.targets.values() + def workflow_requires(self): + reqs = {"model_repository": ExportLocal.req(self)} + reqs.update(self._workflow_requires()) + return reqs - @property - def injection_set_fname(self): - return self.workflow_input()["waveforms"][0].path + @abstractmethod + def _workflow_requires(self): + pass - def workflow_requires(self): - reqs = {} - reqs["model_repository"] = ExportLocal.req(self) - testing_waveforms = TestingWaveforms.req(self) - fetch = testing_waveforms.requires().workflow_requires()[ - "test_segments" - ] - reqs["data"] = fetch - reqs["waveforms"] = testing_waveforms + @abstractmethod + def _workflow_condition(self) -> bool: + pass - return reqs + @abstractmethod + def _create_branch_map(self): + pass - def get_num_shifts(self): - # calculate the number of shifts required - # to accumulate the requested background, - # given the duration of the background segments - segments = 
data_utils.segments_from_paths(self.background_fnames) - num_shifts = data_utils.get_num_shifts_from_Tb( - segments, - self.Tb, - max(self.shifts), - self.psd_length, - ) - return num_shifts + @abstractmethod + def create_sequence(self): + pass @law.dynamic_workflow_condition(cache_met_condition=True) def workflow_condition(self) -> bool: - return self.workflow_input()["data"].collection.exists() + return self._workflow_condition() @workflow_condition.create_branch_map def create_branch_map(self): - # create the individual fname shift - # combinations that represent individual - # condor inference jobs to be submitted - branch_map = {} - num_shifts = self.get_num_shifts() - counter = 0 - for fname in self.background_fnames: - fname = Path(fname.path) - start, duration = map(float, fname.stem.split("-")[-2:]) - stop = start + duration - - if self.zero_lag: - # check if segment is long enough to be analyzed - if data_utils.is_analyzeable_segment( - start, stop, [0] * len(self.shifts), self.psd_length - ): - _shifts = [0 for s in self.shifts] - branch_map[counter] = (fname, _shifts) - counter += 1 - - if num_shifts > 0: - for i in range(num_shifts): - _shifts = [s * (i + 1) for s in self.shifts] - # check if segment is long enough to be analyzed - if data_utils.is_analyzeable_segment( - start, stop, _shifts, self.psd_length - ): - # unique identifier for mapping to branch map - branch_map[counter] = (fname, _shifts) - counter += 1 - - return branch_map + return self._create_branch_map() def get_ip_address(self) -> str: raise NotImplementedError @@ -203,22 +161,12 @@ def output(self): def run(self): from hermes.aeriel.client import InferenceClient - from infer.data import Sequence from infer.main import infer from infer.postprocess import Postprocessor ip = os.getenv("AFRAME_TRITON_IP") self.tmp_dir.mkdir(exist_ok=True, parents=True) - fname, shifts = self.branch_data - sequence = Sequence( - ifos=self.ifos, - batch_size=self.batch_size, - 
inference_sampling_rate=self.inference_sampling_rate, - rate=self.rate_per_client, - shifts=shifts, - background_fname=fname, - injection_set_fname=self.injection_set_fname, - ) + sequence, shifts = self.create_sequence() postprocessor = Postprocessor( integration_window_length=self.integration_window_length, @@ -267,3 +215,138 @@ def run(self): } with open(self.metadata_output, "w") as f: json.dump(metadata, f) + + +@inherits(InferParameters) +class Hdf5InferBase(InferBase): + @property + def background_fnames(self): + return self.workflow_input()["data"].collection.targets.values() + + @property + def injection_set_fname(self): + return self.workflow_input()["waveforms"][0].path + + def _workflow_requires(self): + reqs = {} + testing_waveforms = TestingWaveforms.req(self) + fetch = testing_waveforms.requires().workflow_requires()[ + "test_segments" + ] + reqs["data"] = fetch + reqs["waveforms"] = testing_waveforms + return reqs + + def get_num_shifts(self): + # calculate the number of shifts required + # to accumulate the requested background, + # given the duration of the background segments + segments = data_utils.segments_from_paths(self.background_fnames) + num_shifts = data_utils.get_num_shifts_from_Tb( + segments, + self.Tb, + max(self.shifts), + self.psd_length, + ) + return num_shifts + + def _workflow_condition(self) -> bool: + return self.workflow_input()["data"].collection.exists() + + def _create_branch_map(self): + # create the individual fname shift + # combinations that represent individual + # condor inference jobs to be submitted + branch_map = {} + num_shifts = self.get_num_shifts() + counter = 0 + for fname in self.background_fnames: + fname = Path(fname.path) + start, duration = map(float, fname.stem.split("-")[-2:]) + stop = start + duration + + if self.zero_lag: + # check if segment is long enough to be analyzed + if data_utils.is_analyzeable_segment( + start, stop, [0] * len(self.shifts), self.psd_length + ): + _shifts = [0 for _ in 
self.shifts] + branch_map[counter] = (fname, _shifts) + counter += 1 + + if num_shifts > 0: + for i in range(num_shifts): + _shifts = [s * (i + 1) for s in self.shifts] + # check if segment is long enough to be analyzed + if data_utils.is_analyzeable_segment( + start, stop, _shifts, self.psd_length + ): + # unique identifier for mapping to branch map + branch_map[counter] = (fname, _shifts) + counter += 1 + + return branch_map + + def create_sequence(self): + from infer.data import Hdf5Sequence + + fname, shifts = self.branch_data + sequence = Hdf5Sequence( + ifos=self.ifos, + batch_size=self.batch_size, + inference_sampling_rate=self.inference_sampling_rate, + rate=self.rate_per_client, + shifts=shifts, + background_fname=str(fname), + injection_set_fname=str(self.injection_set_fname), + ) + return sequence, shifts + + +@inherits(InferParameters) +class RnPInferBase(InferBase): + injection_files = luigi.ListParameter(default=[]) + channel = luigi.Parameter( + description="Name of channel within injection files" + ) + sample_rate = luigi.FloatParameter(default=2048.0) + files_per_job = luigi.IntParameter( + default=1, + description="Number of injection files to process per condor job.", + ) + + def _workflow_requires(self): + return {} + + def _workflow_condition(self) -> bool: + return bool(self.injection_files) and all( + Path(f).exists() for f in self.injection_files + ) + + def _create_branch_map(self): + if not self.injection_files: + return {} + + shifts = [0.0 for _ in self.ifos] + files = sorted(self.injection_files) + branch_map = {} + for i, start in enumerate(range(0, len(files), self.files_per_job)): + end = min(start + self.files_per_job, len(files)) + chunk = files[start:end] + branch_map[i] = (chunk, shifts) + return branch_map + + def create_sequence(self): + from infer.data import RnPSequence + + injection_files, shifts = self.branch_data + sequence = RnPSequence( + injection_files=injection_files, + channel=self.channel, + ifos=self.ifos, + 
sample_rate=self.sample_rate, + inference_sampling_rate=self.inference_sampling_rate, + batch_size=self.batch_size, + rate=self.rate_per_client, + ) + return sequence, shifts diff --git a/aframe/tasks/infer/infer.py b/aframe/tasks/infer/infer.py index e92808099..978263c3d 100644 --- a/aframe/tasks/infer/infer.py +++ b/aframe/tasks/infer/infer.py @@ -16,11 +16,14 @@ from luigi.util import inherits from aframe.base import AframeSingularityTask -from aframe.tasks.infer.base import InferBase, InferParameters +from aframe.tasks.infer.base import ( + Hdf5InferBase, + InferParameters, + RnPInferBase, +) -@inherits(InferParameters) -class DeployInferLocal(InferBase): +class _DeployInferLocalMixin: """ Launch inference on local gpus """ @@ -103,11 +106,20 @@ def __exit__(self, *args): return ServerContext(self) -@inherits(DeployInferLocal) -class Infer(AframeSingularityTask): +@inherits(InferParameters) +class DeployInferHdf5Local(_DeployInferLocalMixin, Hdf5InferBase): + pass + + +@inherits(InferParameters) +class DeployInferRnPLocal(_DeployInferLocalMixin, RnPInferBase): + pass + + +class _InferAggregateBase: """ - Law Task that aggregates results from - individual condor inference jobs + Shared aggregation logic for Infer and InferRnP. + Not meant to be instantiated directly. 
""" remove_tmpdir = luigi.BoolParameter( @@ -124,32 +136,16 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.foreground_output = self.output_dir / "foreground.hdf5" self.background_output = self.output_dir / "background.hdf5" - self.zero_lag_output = self.output_dir / "0lag.hdf5" self.timeseries_output = self.output_dir / "timeseries.hdf5" def output(self): output = {} output["foreground"] = law.LocalFileTarget(self.foreground_output) output["background"] = law.LocalFileTarget(self.background_output) - if self.zero_lag: - output["zero_lag"] = law.LocalFileTarget(self.zero_lag_output) if self.return_timeseries: output["timeseries"] = law.LocalFileTarget(self.timeseries_output) return output - def requires(self): - # deploy the condor inference jobs; - # reduce job status poll interval - # so that jobs can be submitted faster - return DeployInferLocal.req( - self, - request_memory=self.request_memory, - request_disk=self.request_disk, - request_cpus=self.request_cpus, - workflow=self.workflow, - poll_interval=0.2, - ) - @property def targets(self): return list(self.input().collection.targets.values()) @@ -183,7 +179,7 @@ def timeseries_files(self): def get_metadata(self): """ Read in shift and length metadata from the metadata - files created by each `DeployInferLocal` condor job. + files created by each condor job. 
This data is read from the metadata files rather than the hdf5 files because the read operation is O(1000) times faster this way @@ -192,7 +188,7 @@ def get_metadata(self): num_files = len(files) background_lengths = np.zeros(num_files) foreground_lengths = np.zeros(num_files) - shifts = np.zeros((num_files, len(self.shifts))) + shifts = np.zeros((num_files, len(self.ifos))) for i, f in enumerate(files): with open(f, "r") as f: data = json.load(f) @@ -230,6 +226,37 @@ def aggregate_timeseries(self): index_array = np.array(index, dtype=dtype) f.create_dataset("index", data=index_array) + +@inherits(DeployInferHdf5Local) +class Infer(_InferAggregateBase, AframeSingularityTask): + """ + Law Task that aggregates results from + individual condor inference jobs + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.zero_lag_output = self.output_dir / "0lag.hdf5" + + def output(self): + output = super().output() + if self.zero_lag: + output["zero_lag"] = law.LocalFileTarget(self.zero_lag_output) + return output + + def requires(self): + # deploy the condor inference jobs; + # reduce job status poll interval + # so that jobs can be submitted faster + return DeployInferHdf5Local.req( + self, + request_memory=self.request_memory, + request_disk=self.request_disk, + request_cpus=self.request_cpus, + workflow=self.workflow, + poll_interval=0.2, + ) + def run(self): import shutil @@ -283,3 +310,57 @@ def run(self): if self.remove_tmpdir: shutil.rmtree(self.output_dir / "tmp") + + +@inherits(DeployInferRnPLocal) +class InferRnP(_InferAggregateBase, AframeSingularityTask): + """ + Law Task that aggregates results from + individual condor inference jobs for RnP inputs. 
+ """ + + def requires(self): + return DeployInferRnPLocal.req( + self, + request_memory=self.request_memory, + request_disk=self.request_disk, + request_cpus=self.request_cpus, + workflow=self.workflow, + poll_interval=0.2, + ) + + def run(self): + import shutil + + from ledger.events import EventSet + + background_lengths, foreground_lengths, _ = self.get_metadata() + background_length = sum(background_lengths) + foreground_length = sum(foreground_lengths) + foreground_mask = foreground_lengths > 0 + + logging.info("Aggregating background files") + EventSet.aggregate( + self.background_files, + self.background_output, + clean=False, + length=background_length, + ) + logging.info("Aggregating foreground files") + EventSet.aggregate( + self.foreground_files[foreground_mask], + self.foreground_output, + clean=False, + length=foreground_length, + ) + + if self.return_timeseries: + logging.info("Aggregating timeseries files") + self.aggregate_timeseries() + + background = EventSet.read(self.background_output) + background = background.sort_by("detection_statistic") + background.write(self.background_output) + + if self.remove_tmpdir: + shutil.rmtree(self.output_dir / "tmp") From d16aacebc67affa6ff63d4261ef3a4d7d4aa0b70 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 17 Mar 2026 08:17:31 -0700 Subject: [PATCH 3/6] Pass file directory rather than list of files --- aframe/tasks/infer/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/aframe/tasks/infer/base.py b/aframe/tasks/infer/base.py index f3a4235d0..407860c91 100644 --- a/aframe/tasks/infer/base.py +++ b/aframe/tasks/infer/base.py @@ -305,7 +305,7 @@ def create_sequence(self): @inherits(InferParameters) class RnPInferBase(InferBase): - injection_files = luigi.ListParameter(default=[]) + injection_file_dir = luigi.PathParameter() channel = luigi.Parameter( description="Name of channel within injection files" ) @@ -315,12 +315,18 @@ class RnPInferBase(InferBase): 
description="Number of injection files to process per condor job.", ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.injection_files = list( + Path(self.injection_file_dir).glob("*.hdf") + ) + def _workflow_requires(self): return {} def _workflow_condition(self) -> bool: return bool(self.injection_files) and all( - Path(f).exists() for f in self.injection_files + f.exists() for f in self.injection_files ) def _create_branch_map(self): From 6801348ccfe6287eea1d6e0e3f8948518552ac97 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 17 Mar 2026 08:18:36 -0700 Subject: [PATCH 4/6] Remove unneeded end parameter --- aframe/tasks/infer/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aframe/tasks/infer/base.py b/aframe/tasks/infer/base.py index 407860c91..4530cf8cd 100644 --- a/aframe/tasks/infer/base.py +++ b/aframe/tasks/infer/base.py @@ -337,8 +337,7 @@ def _create_branch_map(self): files = sorted(self.injection_files) branch_map = {} for i, start in enumerate(range(0, len(files), self.files_per_job)): - end = min(start + self.files_per_job, len(files)) - chunk = files[start:end] + chunk = files[start : start + self.files_per_job] branch_map[i] = (chunk, shifts) return branch_map From 18cc75f05c9b19aebb3cd1aa0000980f35af660c Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 17 Mar 2026 12:50:19 -0700 Subject: [PATCH 5/6] Bug fixes --- aframe/tasks/infer/base.py | 12 +- projects/infer/infer/data.py | 10 +- projects/infer/infer/main.py | 2 +- projects/infer/pyproject.toml | 1 + projects/infer/uv.lock | 348 ++++++++++++++++++++++++++++++++-- 5 files changed, 354 insertions(+), 19 deletions(-) diff --git a/aframe/tasks/infer/base.py b/aframe/tasks/infer/base.py index 4530cf8cd..8372dfef7 100644 --- a/aframe/tasks/infer/base.py +++ b/aframe/tasks/infer/base.py @@ -14,7 +14,7 @@ from aframe.base import AframeSingularityTask from aframe.config import paths from aframe.parameters import PathParameter 
-from aframe.tasks import ExportLocal, TestingWaveforms +from aframe.tasks import ExportLocal, TestingWaveforms, Train from aframe.tasks.data.condor.workflows import StaticMemoryWorkflow @@ -117,7 +117,11 @@ def metadata_output(self): return self.tmp_dir / "metadata.json" def workflow_requires(self): - reqs = {"model_repository": ExportLocal.req(self)} + reqs = { + "model_repository": ExportLocal.req( + self, train_task=self.train_task + ) + } reqs.update(self._workflow_requires()) return reqs @@ -305,6 +309,8 @@ def create_sequence(self): @inherits(InferParameters) class RnPInferBase(InferBase): + train_task = Train + injection_file_dir = luigi.PathParameter() channel = luigi.Parameter( description="Name of channel within injection files" @@ -318,7 +324,7 @@ class RnPInferBase(InferBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.injection_files = list( - Path(self.injection_file_dir).glob("*.hdf") + Path(self.injection_file_dir).rglob("*.gwf") ) def _workflow_requires(self): diff --git a/projects/infer/infer/data.py b/projects/infer/infer/data.py index 567f1428e..d1b0a93aa 100644 --- a/projects/infer/infer/data.py +++ b/projects/infer/infer/data.py @@ -3,6 +3,7 @@ import re from abc import ABC, abstractmethod from contextlib import nullcontext +from pathlib import Path from typing import Optional from zlib import adler32 @@ -207,6 +208,7 @@ def _setup( shifts: list[float], ): self.background_fname = background_fname + self.inference_filenames = [background_fname] self.ifos = ifos if len(ifos) != len(shifts): @@ -308,7 +310,7 @@ def recover(self, foreground: EventSet) -> RecoveredInjectionSet: class RnPSequence(BaseSequence): def __init__( self, - injection_files: list[str], + injection_files: list[Path], channel: str, ifos: list[str], sample_rate: float, @@ -355,7 +357,7 @@ def __init__( def _setup( self, - injection_files: list[str], + injection_files: list[Path], channel: str, ifos: list[str], ): @@ -370,7 +372,9 @@ def _setup( 
injection_files = sorted(injection_files) - matches = [FNAME_RE.search(fname) for fname in injection_files] + self.inference_filenames = [fname.name for fname in injection_files] + + matches = [FNAME_RE.search(fname.name) for fname in injection_files] if not all(matches): raise ValueError( "All injection files must match expected name pattern" diff --git a/projects/infer/infer/main.py b/projects/infer/infer/main.py index ce2250b6c..33686061c 100644 --- a/projects/infer/infer/main.py +++ b/projects/infer/infer/main.py @@ -38,7 +38,7 @@ def infer( "at GPS time {}".format( sequence.id, sequence.duration, - sequence.background_fname, + sequence.inference_filenames, sequence.shifts / sequence.sample_rate, sequence.sample_rate, sequence.t0, diff --git a/projects/infer/pyproject.toml b/projects/infer/pyproject.toml index eee5b12ac..954003b36 100644 --- a/projects/infer/pyproject.toml +++ b/projects/infer/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "utils", "ledger", "urllib3>=1.25.4,<1.27", + "gwpy", ] [project.scripts] diff --git a/projects/infer/uv.lock b/projects/infer/uv.lock index 5502aeeb5..34b05aea0 100644 --- a/projects/infer/uv.lock +++ b/projects/infer/uv.lock @@ -3,7 +3,8 @@ revision = 3 requires-python = ">=3.10, <3.13" resolution-markers = [ "python_full_version >= '3.12'", - "python_full_version < '3.12'", + "python_full_version == '3.11.*'", + "python_full_version < '3.11'", ] [[package]] @@ -245,12 +246,15 @@ wheels = [ name = "astropy" version = "6.0.1" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] dependencies = [ - { name = "astropy-iers-data" }, - { name = "numpy" }, - { name = "packaging" }, - { name = "pyerfa" }, - { name = "pyyaml" }, + { name = "astropy-iers-data", version = "0.2025.2.10.0.33.26", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", marker = "python_full_version < '3.11'" }, + { name = "packaging", 
marker = "python_full_version < '3.11'" }, + { name = "pyerfa", marker = "python_full_version < '3.11'" }, + { name = "pyyaml", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/48/08/f205a24d75ad1f329586bb685b53574c5303c56acf80924166a6c8df8a09/astropy-6.0.1.tar.gz", hash = "sha256:89a975de356d0608e74f1f493442fb3acbbb7a85b739e074460bb0340014b39c", size = 7074537, upload-time = "2024-03-26T18:30:03.408Z" } wheels = [ @@ -277,15 +281,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/f1/451851b269855c50fb468d90272adcaf10a643ccf2da6433f8e153fd48d8/astropy-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:8fbd6d88935749ae892445691ac0dbd1923fc6d8094753a35150fc7756118fe3", size = 6363162, upload-time = "2024-03-26T18:29:34.61Z" }, ] +[[package]] +name = "astropy" +version = "7.2.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "astropy-iers-data", version = "0.2026.3.16.0.53.33", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", marker = "python_full_version >= '3.11'" }, + { name = "packaging", marker = "python_full_version >= '3.11'" }, + { name = "pyerfa", marker = "python_full_version >= '3.11'" }, + { name = "pyyaml", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7b/92/2dce2d48347efc3346d08ca7995b152d242ebd170c571f7c9346468d8427/astropy-7.2.0.tar.gz", hash = "sha256:ae48bc26b1feaeb603cd94bd1fa1aa39137a115fe931b7f13787ab420e8c3070", size = 7057774, upload-time = "2025-11-25T22:36:41.916Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b4/6d/6330a844bad8dfc4875e0f2fa1db1fee87837ba9805aa8a8d048c071363a/astropy-7.2.0-cp311-abi3-macosx_10_9_x86_64.whl", hash = "sha256:efac04df4cc488efe630c2fff1992d6516dfb16a06e197fb68bc9e8e3b85def1", 
size = 6442332, upload-time = "2025-11-25T22:36:23.6Z" }, + { url = "https://files.pythonhosted.org/packages/a6/ba/3418133ba144dfcd1530bca5a6b695f4cdd21a8abaaa2ac4e5450d11b028/astropy-7.2.0-cp311-abi3-macosx_11_0_arm64.whl", hash = "sha256:52e9a7d9c86b21f1af911a2930cd0c4a275fb302d455c89e11eedaffef6f2ad0", size = 6413656, upload-time = "2025-11-25T22:36:26.548Z" }, + { url = "https://files.pythonhosted.org/packages/be/ba/05e43b5a7d738316a097fa78524d3eaaff5986294b4a052d4adb3c45e7c0/astropy-7.2.0-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97c370421b9bb13d4c762c7af06d172bad7c01bd5bcf88314f6913c3c235b770", size = 9758867, upload-time = "2025-11-25T22:36:28.661Z" }, + { url = "https://files.pythonhosted.org/packages/c3/1c/f06ad85180e7dd9855aa5ede901bfc2be858d7bee17d4e978a14c0ecec14/astropy-7.2.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f39ce2c80211fbceb005d377a5478cd0d66c42aa1498d252f2239fe5a025c24", size = 9789007, upload-time = "2025-11-25T22:36:31.063Z" }, + { url = "https://files.pythonhosted.org/packages/f8/fb/e4d35194a5009d7a73333079481a4ef1380a255d67b9c1db578151a5fb50/astropy-7.2.0-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ad4d71db994d45f046a1a5449000cf0f88ab6367cb67658500654a0586d6ab19", size = 9748547, upload-time = "2025-11-25T22:36:33.154Z" }, + { url = "https://files.pythonhosted.org/packages/36/ea/f990730978ae0a7a34705f885d2f3806928c5f0bc22eefd6a1a23539cc32/astropy-7.2.0-cp311-abi3-win32.whl", hash = "sha256:95161f26602433176483e8bde8ab1a8ca09148f5b4bf5190569a26d381091598", size = 6237228, upload-time = "2025-11-25T22:36:35.236Z" }, + { url = "https://files.pythonhosted.org/packages/ec/bc/f4378f586dd63902c37d16f68f35f7d555b3b32e08ac6b1d633eb0a48805/astropy-7.2.0-cp311-abi3-win_amd64.whl", hash = "sha256:dc7c340ba1713e55c93071b32033f3153470a0f663a4d539c03a7c9b44020790", size = 6362868, upload-time = "2025-11-25T22:36:37.784Z" }, + { url 
= "https://files.pythonhosted.org/packages/77/79/b6d4bf01913cfd4ce0cd4c1be5916beccdb92b2970bab8c827984231eae6/astropy-7.2.0-cp311-abi3-win_arm64.whl", hash = "sha256:0c428735a3f15b05c2095bc6ccb5f98a64bc99fb7015866af19ff8492420ddaf", size = 6221756, upload-time = "2025-11-25T22:36:39.852Z" }, +] + [[package]] name = "astropy-iers-data" version = "0.2025.2.10.0.33.26" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] sdist = { url = "https://files.pythonhosted.org/packages/28/3b/f9039a7062715eef8ac77aa886da69364e663a054ef836a64cbf18a048ec/astropy_iers_data-0.2025.2.10.0.33.26.tar.gz", hash = "sha256:03d93817588ef2344e22d56f7a11cba2ecd877ddb2d0fc259a1daf3980c33c3e", size = 1892096, upload-time = "2025-02-10T00:34:11.776Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/85/41b056b6ea33811539a1ff6595c6bd9d598b247402c0b351aee5efb37ef2/astropy_iers_data-0.2025.2.10.0.33.26-py3-none-any.whl", hash = "sha256:e55fb8578bc3c5e54113aae624f94e111bb89bdb57220958c7e673784b5c3b68", size = 1944899, upload-time = "2025-02-10T00:34:09.378Z" }, ] +[[package]] +name = "astropy-iers-data" +version = "0.2026.3.16.0.53.33" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +sdist = { url = "https://files.pythonhosted.org/packages/11/5e/48429424e4b972d5ebceb46f578b3d7848aee0a75bc78acb18b4f3967ded/astropy_iers_data-0.2026.3.16.0.53.33.tar.gz", hash = "sha256:8da3b6c56573cf63ec99c0e7b4ab74be7dc5af2aaa4a62fb671879f7411acbb6", size = 1928977, upload-time = "2026-03-16T00:54:24.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/7f/223a0d2ba7ea0c798b59e5a0f61782347cf9292274a1f7f3eef0a02c1536/astropy_iers_data-0.2026.3.16.0.53.33-py3-none-any.whl", hash = "sha256:f8e118ace0727540131384fe5a07fddbab970a9a368fb47e46e7ca7166b9557c", size = 1985352, upload-time = "2026-03-16T00:54:22.885Z" }, +] + 
[[package]] name = "asttokens" version = "3.0.0" @@ -553,6 +600,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767, upload-time = "2024-12-24T18:12:32.852Z" }, ] +[[package]] +name = "click" +version = "8.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065, upload-time = "2025-11-15T20:45:42.706Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6", size = 108274, upload-time = "2025-11-15T20:45:41.139Z" }, +] + [[package]] name = "cloudpathlib" version = "0.18.1" @@ -734,6 +793,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/6b/7c87867d255cbce8167ed99fc65635e9395d2af0f0c915428f5b17ec412d/Cython-3.0.12-py2.py3-none-any.whl", hash = "sha256:0038c9bae46c459669390e53a1ec115f8096b2e4647ae007ff1bf4e6dee92806", size = 1171640, upload-time = "2025-02-11T09:05:45.648Z" }, ] +[[package]] +name = "dateparser" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "regex" }, + { name = "tzlocal" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/2c/668dfb8c073a5dde3efb80fa382de1502e3b14002fd386a8c1b0b49e92a9/dateparser-1.3.0.tar.gz", hash = "sha256:5bccf5d1ec6785e5be71cc7ec80f014575a09b4923e762f850e57443bddbf1a5", 
size = 337152, upload-time = "2026-02-04T16:00:06.162Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9a/c7/95349670e193b2891176e1b8e5f43e12b31bff6d9994f70e74ab385047f6/dateparser-1.3.0-py3-none-any.whl", hash = "sha256:8dc678b0a526e103379f02ae44337d424bd366aac727d3c6cf52ce1b01efbb5a", size = 318688, upload-time = "2026-02-04T16:00:04.652Z" }, +] + [[package]] name = "debugpy" version = "1.8.12" @@ -773,6 +847,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] +[[package]] +name = "dqsegdb2" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "igwn-auth-utils" }, + { name = "igwn-segments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c4/37/7874b39abede48fe05c3b5a25092e419e521145e9705c97b711965e5f05d/dqsegdb2-1.3.0.tar.gz", hash = "sha256:4e291899cd395daf5913c48a835d213ef20833d78354871cb848988cc2cc8ae4", size = 33661, upload-time = "2025-01-08T16:36:58.072Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/67/c977e7307d432911fe28acf55dba01f84f3a0a4a2a968350195c0f52d23e/dqsegdb2-1.3.0-py3-none-any.whl", hash = "sha256:c93087da332b7d91519a370abed3cfc91b64a6bf393ec64d79c2747c837aeb66", size = 27957, upload-time = "2025-01-08T16:36:56.881Z" }, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -1063,15 +1151,106 @@ wheels = [ name = "gwdatafind" version = "1.2.0" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] dependencies = [ - { name = "igwn-auth-utils" }, - { name = "ligo-segments" }, + { name = "igwn-auth-utils", marker = "python_full_version < '3.11'" }, + { name = "ligo-segments", marker = "python_full_version 
< '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/2b/10/9f1b9100f59e2ca4a85dad8a21942d0702d756f4b80a433c728be4a871d2/gwdatafind-1.2.0.tar.gz", hash = "sha256:8f74942e66cdb9a53030da29069110b3cb30afc2a034790957786028fb09f451", size = 40381, upload-time = "2023-12-18T09:19:43.234Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/89/61/5020eff070e04b1e07c7cf8bed63705aa705011e057cbb839e9a31367bdd/gwdatafind-1.2.0-py3-none-any.whl", hash = "sha256:58c505ee188c1186ff81b3de5f946f289179a4f8c334f7eb45d07dd70a71bd2c", size = 45529, upload-time = "2023-12-18T09:19:41.129Z" }, ] +[[package]] +name = "gwdatafind" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "igwn-auth-utils", marker = "python_full_version >= '3.11'" }, + { name = "igwn-segments", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/dd/4f517a2a36f71d5bb74b8b0796e3fec87703587503d4583f4a5b962ae60b/gwdatafind-2.1.1.tar.gz", hash = "sha256:e4710256daa7b47e901da2f2846620c551e9caaaaf22b7773c81a8ae052da43e", size = 41311, upload-time = "2025-10-31T09:46:57.9Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/00/07a1a9473ea5bbbc0be3cea79778dfd31e7e199404777907930badc9f399/gwdatafind-2.1.1-py3-none-any.whl", hash = "sha256:6e6d430fa243e6241ca0c214f1916f7973cf1937716bbca52d99a9b88650faeb", size = 45178, upload-time = "2025-10-31T09:46:56.468Z" }, +] + +[[package]] +name = "gwosc" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3b/4e/6756e30841a81540ea38d429892f0900c9a33fa5c6269697fd021c20f018/gwosc-0.8.1.tar.gz", hash = "sha256:d4a890147ffbd76bfa800f6d01dd614cfb85b7623c537b93d713656144a9edff", size = 35494, upload-time = 
"2025-06-04T20:56:20.341Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/63/5dc667739398c194679cb44813e1e180fd2dff843155c1c7d8391276f2bc/gwosc-0.8.1-py3-none-any.whl", hash = "sha256:e6040e1901f8b8e43a7ed292f92a17eb67490a85124c93d45b2ab50ae4000b41", size = 32222, upload-time = "2025-06-04T20:56:19.034Z" }, +] + +[[package]] +name = "gwpy" +version = "3.0.14" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] +dependencies = [ + { name = "astropy", version = "6.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "dateparser", marker = "python_full_version < '3.11'" }, + { name = "dqsegdb2", marker = "python_full_version < '3.11'" }, + { name = "gwdatafind", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "gwosc", marker = "python_full_version < '3.11'" }, + { name = "h5py", marker = "python_full_version < '3.11'" }, + { name = "igwn-segments", marker = "python_full_version < '3.11'" }, + { name = "ligotimegps", version = "2.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "matplotlib", marker = "python_full_version < '3.11'" }, + { name = "numpy", marker = "python_full_version < '3.11'" }, + { name = "packaging", marker = "python_full_version < '3.11'" }, + { name = "python-dateutil", marker = "python_full_version < '3.11'" }, + { name = "requests", marker = "python_full_version < '3.11'" }, + { name = "scipy", marker = "python_full_version < '3.11'" }, + { name = "tqdm", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/cd/bb165e47b18cd255c1ee1fcd59802544d8de57566b449cfdaf848b69108f/gwpy-3.0.14.tar.gz", hash = "sha256:bf25d19763c9128f515144349d3f70925a8334bbd22950c01d35bb954f323f52", size = 1543294, upload-time = 
"2026-01-16T13:02:28.432Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/46/c893a66bd97f08ddf4a7cfde52cc58fb8e2c83acfa915f3c02c00f0aa721/gwpy-3.0.14-py3-none-any.whl", hash = "sha256:f61ed9d4b3eba7f9f534f72f03956f158ed7f9101726b5387c29fe57ce45c77e", size = 1395825, upload-time = "2026-01-16T13:02:26.55Z" }, +] + +[[package]] +name = "gwpy" +version = "4.0.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "astropy", version = "7.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "dateparser", marker = "python_full_version >= '3.11'" }, + { name = "dqsegdb2", marker = "python_full_version >= '3.11'" }, + { name = "gwdatafind", version = "2.1.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "gwosc", marker = "python_full_version >= '3.11'" }, + { name = "h5py", marker = "python_full_version >= '3.11'" }, + { name = "igwn-segments", marker = "python_full_version >= '3.11'" }, + { name = "ligotimegps", version = "2.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "matplotlib", marker = "python_full_version >= '3.11'" }, + { name = "numpy", marker = "python_full_version >= '3.11'" }, + { name = "packaging", marker = "python_full_version >= '3.11'" }, + { name = "python-dateutil", marker = "python_full_version >= '3.11'" }, + { name = "requests", marker = "python_full_version >= '3.11'" }, + { name = "scipy", marker = "python_full_version >= '3.11'" }, + { name = "tqdm", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/92/71/1e652734a0e8f23e0f9dbf5df5698e66fa49aef4446d2c6c9eb914921192/gwpy-4.0.1.tar.gz", hash = "sha256:477aa69bd40506bfb0df5ec779d9dd6fef562a3bfabe58a52b297e25c57f61f0", 
size = 1615849, upload-time = "2026-02-03T10:14:59.045Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/ac/1debaaec791da4475687d22c8c2cbccc2008885046754f5e9f65997886b8/gwpy-4.0.1-py3-none-any.whl", hash = "sha256:587eb66443e41450bd89bf64891e323943afad3be385fccc1645eda2710908b7", size = 1559010, upload-time = "2026-02-03T10:14:56.928Z" }, +] + [[package]] name = "h11" version = "0.14.0" @@ -1159,12 +1338,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c5/7d/07b9b5e6421362a27db4fdeac60e211fc07d47b94c2085bcfc2cd76192ae/igwn_auth_utils-1.1.1-py3-none-any.whl", hash = "sha256:f995d79214afbcb05823d46b33a9fd96cfa7734431a2ca5beeddb09c0452da83", size = 26712, upload-time = "2024-09-09T10:42:05.178Z" }, ] +[[package]] +name = "igwn-segments" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/92/ef/3500687ef4a61887bc218ab7ef8352ab7110afababaa82e958f2ca95967b/igwn_segments-2.1.1.tar.gz", hash = "sha256:672814d78f5de7582b5f39189cdc6016d7ffde95d2b8b8b8b82aa9b45e53f3dd", size = 60805, upload-time = "2026-01-19T17:51:05.142Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a5/87/f5dc189af8817fc722dfac1798bf0d70a4b795701be4adafd215a45e3b8d/igwn_segments-2.1.1-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:185993971cf4cac593156f47361ff62f51158bc197f2cdbad6f8e0165bae43dd", size = 49090, upload-time = "2026-01-19T17:50:25.921Z" }, + { url = "https://files.pythonhosted.org/packages/51/2b/cf74ffaff55de002a5717c3843c5c40b019587a8f55c26c62d1fdf6f9c7b/igwn_segments-2.1.1-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:48867076f643508cb9f66fd460d72548ea60dab5192f49770d77883be7f00275", size = 49251, upload-time = "2026-01-19T17:50:27.427Z" }, + { url = 
"https://files.pythonhosted.org/packages/8d/d7/42cda68c97f11532dc8c8a7db93ee9593ef9c0f52b5d26db5dfe06934155/igwn_segments-2.1.1-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:edd31f115f8cc3411491300f56279be3417afde486b6cab6ee80c43c1bae5034", size = 108410, upload-time = "2026-01-19T17:50:28.603Z" }, + { url = "https://files.pythonhosted.org/packages/4c/b5/fdea6bf16f756811c6c9225e1e9c79ddf57a37c1d9609fb374afb8912573/igwn_segments-2.1.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:76b5626ed1f81197333fdafeac262c15339c5e5e21a92a1e4254b01487fd49ad", size = 109833, upload-time = "2026-01-19T17:50:29.948Z" }, + { url = "https://files.pythonhosted.org/packages/1e/16/1a272ce8fe797a962d6a66b2c5a43528df3b5368ef432e5537785d2a617b/igwn_segments-2.1.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:74b765df7d2ac6e46d48a74ee8dda6663748e9341b6d06f7f5968443564b5822", size = 108421, upload-time = "2026-01-19T17:50:30.933Z" }, + { url = "https://files.pythonhosted.org/packages/91/47/ac87852a682e548ebc926fd97975b92352f63b54943a96b6bbcb7ffa7d1a/igwn_segments-2.1.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:22abde65098e53b41a299f0e7606b23671e07cae87128c95dfb0baf80f0ff6a8", size = 107778, upload-time = "2026-01-19T17:50:31.951Z" }, + { url = "https://files.pythonhosted.org/packages/bf/c3/3ba15630cf8b873cb10bf188793b7828cf2f7e13ba0bd5673bdad42e8636/igwn_segments-2.1.1-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:774c84a5c0ebf4fe6e4de3e1b8c73c02b6f53db96598a07e7d26de9cd156ef64", size = 49068, upload-time = "2026-01-19T17:50:32.953Z" }, + { url = "https://files.pythonhosted.org/packages/b4/b1/7ffbe8b7fe6b6597d913bfd32db3898ee3d1495003bfedd497b14ea696b9/igwn_segments-2.1.1-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:8cc0dde1d4f9de3a36eddf38e25baa166492737c93ad71658db25dfc46745ddd", size = 49249, upload-time = "2026-01-19T17:50:34.158Z" }, + { url = 
"https://files.pythonhosted.org/packages/71/f7/9110ee0e83f43ba97155897be0a864d53801ff259a674bc8e37f80e854b9/igwn_segments-2.1.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6c2231f6a8e2adeb6403e4ba5fa8fb9cd9bcb7c55fa9e67e90139020a32aff79", size = 112480, upload-time = "2026-01-19T17:50:35.064Z" }, + { url = "https://files.pythonhosted.org/packages/50/06/eb25c8ffefcbe8c9f3275855bbfad607e20d61ea6c3f5be778f8b5c0d82f/igwn_segments-2.1.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cfbef53fcd95c56e200b9211cd3bad4fa48c5b82586166c6fdb44306ea3b9289", size = 114264, upload-time = "2026-01-19T17:50:36.092Z" }, + { url = "https://files.pythonhosted.org/packages/84/18/ce2b83e8993ddb4d8f6c409ff3405f3e424215a88876b019f7a0d9b38d70/igwn_segments-2.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5ba2483b066c8f95e6e916437f7a375c0497313e1a425642b89ef1d647980f80", size = 112686, upload-time = "2026-01-19T17:50:37.479Z" }, + { url = "https://files.pythonhosted.org/packages/13/78/ec606472610da3f545cbdd03b134fc09cfb86eae3674f5198ee8a4c36501/igwn_segments-2.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5351901272f8a33b4cb8546930e744bcd912c5d899879e8477f064cd11d0d874", size = 111852, upload-time = "2026-01-19T17:50:38.627Z" }, + { url = "https://files.pythonhosted.org/packages/71/78/6764ceb793728334df4dc87d9089ca5a8010ad280c7ca13c0e4304818f83/igwn_segments-2.1.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:d8a1e36f21f9420f254a7f78c289035c8d9dc4e0f44b5a77495e87d71469bd7b", size = 49329, upload-time = "2026-01-19T17:50:39.843Z" }, + { url = "https://files.pythonhosted.org/packages/0c/76/eb614178340cbae66af617303f80748d61d74c57f0b02105b9ca85c90033/igwn_segments-2.1.1-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:139abaec24edc25e7409fe48445c5aa8d9448e98c038079d84485497abc838c0", size = 49704, upload-time = "2026-01-19T17:50:41.013Z" }, + { url = 
"https://files.pythonhosted.org/packages/e0/5c/f10140682e77bdf4160cd474bbc21a8e8d362dfec46131b4e66d5929b04e/igwn_segments-2.1.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f550aa55fa4165c1d3d8571d67523a010b5b365ee363516e542f732cbe54a1d8", size = 114953, upload-time = "2026-01-19T17:50:42.126Z" }, + { url = "https://files.pythonhosted.org/packages/61/05/b347e6ac22c5cbae568e410c56729a188d9d64ce340e859fc8db5b82f195/igwn_segments-2.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9aa802dfd96d24909975d92c24aee088d4be0fb3636a2d893cb196bbeb19f175", size = 115998, upload-time = "2026-01-19T17:50:43.169Z" }, + { url = "https://files.pythonhosted.org/packages/5d/a5/f6a2b14786e62630fa3adad1b87155c4d71c3b1360e8563092793b81dfae/igwn_segments-2.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:74abd9ab5502bfba09343d2d6fcffb7c75f935385e1ca389b8ced82e9fe67a8f", size = 114540, upload-time = "2026-01-19T17:50:44.438Z" }, + { url = "https://files.pythonhosted.org/packages/48/9e/604b28e22207db111c9f43e4214a7aa086d6d65be38c024d393a5aaab781/igwn_segments-2.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:029a7d55c786ef24f1c07850a291e111ef2fceeaae2518347203dcaa019f811c", size = 114640, upload-time = "2026-01-19T17:50:45.749Z" }, +] + [[package]] name = "infer" version = "0.0.1" source = { editable = "." 
} dependencies = [ { name = "aframe" }, + { name = "gwpy", version = "3.0.14", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "gwpy", version = "4.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "jsonargparse" }, { name = "ledger" }, { name = "ml4gw-hermes", extra = ["torch"] }, @@ -1183,6 +1390,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aframe", editable = "../../" }, + { name = "gwpy" }, { name = "jsonargparse", specifier = "~=4.24" }, { name = "ledger", editable = "../../libs/ledger" }, { name = "ml4gw-hermes", extras = ["torch"], specifier = ">=0.2.1" }, @@ -1676,7 +1884,8 @@ name = "lalsuite" version = "7.25.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "astropy" }, + { name = "astropy", version = "6.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "astropy", version = "7.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "ligo-segments" }, { name = "lscsoft-glue" }, { name = "matplotlib" }, @@ -1748,6 +1957,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/37/c962f26408ce45271a5a3aaa918f7beae74a8e8a6f00bbe5fdf77fd778ca/ligo_segments-1.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:601be6d92e52bdebbb5a82b608ed1ceb781a0ea86b3e5333d61af77575b32664", size = 50555, upload-time = "2023-10-09T11:39:10.788Z" }, ] +[[package]] +name = "ligotimegps" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] +sdist = { url = "https://files.pythonhosted.org/packages/39/2e/cef2ee4c4f3f1f04566e3e7e9343811a74f6e9a0bc6ef4711248f132e3bb/ligotimegps-2.0.1.tar.gz", hash = "sha256:88626c02ad9a464d1242a1147b40074792f424bafa2ab013eee629c7d1b6469c", size = 35191, upload-time = "2019-04-25T16:00:25.481Z" } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/69/b6/6d6d0585fa2ae936a9f5d411b1f0fbe9fcb0aca0c51a775aa4f8f95fdf5e/ligotimegps-2.0.1-py2.py3-none-any.whl", hash = "sha256:da8c1289ba1310337ef5177e7936e25ce47d4e8e6a269cbdd5e9abfc5b5db490", size = 19930, upload-time = "2019-04-25T16:00:23.354Z" }, +] + +[[package]] +name = "ligotimegps" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +sdist = { url = "https://files.pythonhosted.org/packages/7e/9b/521be61daa7603603826d40fa6327a22e04df4f20e69d9ed42182370a7f8/ligotimegps-2.1.0.tar.gz", hash = "sha256:d948ffc4d58472b303478a983ec56d2e6a7f35e54cbb351ca6f1292e68398cb2", size = 28069, upload-time = "2025-10-07T10:55:28.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/92/38674acb59c663bcdb69522e11afd3b4e6371f74b5764aeec2e99308cd79/ligotimegps-2.1.0-py3-none-any.whl", hash = "sha256:14dbbb07b175b94b4e1a519d7baa4548f0ea07bc71ef7d7f096ad8397d359043", size = 23370, upload-time = "2025-10-07T10:55:28.855Z" }, +] + [[package]] name = "lockfile" version = "0.12.2" @@ -2561,10 +2795,12 @@ name = "pycbc" version = "2.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "astropy" }, + { name = "astropy", version = "6.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "astropy", version = "7.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "beautifulsoup4" }, { name = "cython" }, - { name = "gwdatafind" }, + { name = "gwdatafind", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "gwdatafind", version = "2.1.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "h5py" }, { name = "jinja2" }, { name = 
"lalsuite" }, @@ -2834,6 +3070,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/97/42a550a79ab90ab37fcd8b519cd71bba4b96b85679218100d63b437770c0/python_rapidjson-1.20-cp312-cp312-win_amd64.whl", hash = "sha256:5d3be149ce5475f9605f01240487541057792abad94d3fd0cd56af363cf5a4dc", size = 149067, upload-time = "2024-08-05T17:55:49.834Z" }, ] +[[package]] +name = "pytz" +version = "2026.1.post1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/db/b8721d71d945e6a8ac63c0fc900b2067181dbb50805958d4d4661cf7d277/pytz-2026.1.post1.tar.gz", hash = "sha256:3378dde6a0c3d26719182142c56e60c7f9af7e968076f31aae569d72a0358ee1", size = 321088, upload-time = "2026-03-03T07:47:50.683Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/99/781fe0c827be2742bcc775efefccb3b048a3a9c6ce9aec0cbf4a101677e5/pytz-2026.1.post1-py2.py3-none-any.whl", hash = "sha256:f2fd16142fda348286a75e1a524be810bb05d444e5a081f37f7affc635035f7a", size = 510489, upload-time = "2026-03-03T07:47:49.167Z" }, +] + [[package]] name = "pywin32" version = "308" @@ -2971,6 +3216,63 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/b1/3baf80dc6d2b7bc27a95a67752d0208e410351e3feb4eb78de5f77454d8d/referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0", size = 26775, upload-time = "2025-01-25T08:48:14.241Z" }, ] +[[package]] +name = "regex" +version = "2026.2.28" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/71/41455aa99a5a5ac1eaf311f5d8efd9ce6433c03ac1e0962de163350d0d97/regex-2026.2.28.tar.gz", hash = "sha256:a729e47d418ea11d03469f321aaf67cdee8954cde3ff2cf8403ab87951ad10f2", size = 415184, upload-time = "2026-02-28T02:19:42.792Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/70/b8/845a927e078f5e5cc55d29f57becbfde0003d52806544531ab3f2da4503c/regex-2026.2.28-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fc48c500838be6882b32748f60a15229d2dea96e59ef341eaa96ec83538f498d", size = 488461, upload-time = "2026-02-28T02:15:48.405Z" }, + { url = "https://files.pythonhosted.org/packages/32/f9/8a0034716684e38a729210ded6222249f29978b24b684f448162ef21f204/regex-2026.2.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2afa673660928d0b63d84353c6c08a8a476ddfc4a47e11742949d182e6863ce8", size = 290774, upload-time = "2026-02-28T02:15:51.738Z" }, + { url = "https://files.pythonhosted.org/packages/a6/ba/b27feefffbb199528dd32667cd172ed484d9c197618c575f01217fbe6103/regex-2026.2.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7ab218076eb0944549e7fe74cf0e2b83a82edb27e81cc87411f76240865e04d5", size = 288737, upload-time = "2026-02-28T02:15:53.534Z" }, + { url = "https://files.pythonhosted.org/packages/18/c5/65379448ca3cbfe774fcc33774dc8295b1ee97dc3237ae3d3c7b27423c9d/regex-2026.2.28-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:94d63db12e45a9b9f064bfe4800cefefc7e5f182052e4c1b774d46a40ab1d9bb", size = 782675, upload-time = "2026-02-28T02:15:55.488Z" }, + { url = "https://files.pythonhosted.org/packages/aa/30/6fa55bef48090f900fbd4649333791fc3e6467380b9e775e741beeb3231f/regex-2026.2.28-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:195237dc327858a7721bf8b0bbbef797554bc13563c3591e91cd0767bacbe359", size = 850514, upload-time = "2026-02-28T02:15:57.509Z" }, + { url = "https://files.pythonhosted.org/packages/a9/28/9ca180fb3787a54150209754ac06a42409913571fa94994f340b3bba4e1e/regex-2026.2.28-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b387a0d092dac157fb026d737dde35ff3e49ef27f285343e7c6401851239df27", size = 896612, upload-time = 
"2026-02-28T02:15:59.682Z" }, + { url = "https://files.pythonhosted.org/packages/46/b5/f30d7d3936d6deecc3ea7bea4f7d3c5ee5124e7c8de372226e436b330a55/regex-2026.2.28-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3935174fa4d9f70525a4367aaff3cb8bc0548129d114260c29d9dfa4a5b41692", size = 791691, upload-time = "2026-02-28T02:16:01.752Z" }, + { url = "https://files.pythonhosted.org/packages/f5/34/96631bcf446a56ba0b2a7f684358a76855dfe315b7c2f89b35388494ede0/regex-2026.2.28-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2b2b23587b26496ff5fd40df4278becdf386813ec00dc3533fa43a4cf0e2ad3c", size = 783111, upload-time = "2026-02-28T02:16:03.651Z" }, + { url = "https://files.pythonhosted.org/packages/39/54/f95cb7a85fe284d41cd2f3625e0f2ae30172b55dfd2af1d9b4eaef6259d7/regex-2026.2.28-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:3b24bd7e9d85dc7c6a8bd2aa14ecd234274a0248335a02adeb25448aecdd420d", size = 767512, upload-time = "2026-02-28T02:16:05.616Z" }, + { url = "https://files.pythonhosted.org/packages/3d/af/a650f64a79c02a97f73f64d4e7fc4cc1984e64affab14075e7c1f9a2db34/regex-2026.2.28-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:bd477d5f79920338107f04aa645f094032d9e3030cc55be581df3d1ef61aa318", size = 773920, upload-time = "2026-02-28T02:16:08.325Z" }, + { url = "https://files.pythonhosted.org/packages/72/f8/3f9c2c2af37aedb3f5a1e7227f81bea065028785260d9cacc488e43e6997/regex-2026.2.28-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:b49eb78048c6354f49e91e4b77da21257fecb92256b6d599ae44403cab30b05b", size = 846681, upload-time = "2026-02-28T02:16:10.381Z" }, + { url = "https://files.pythonhosted.org/packages/54/12/8db04a334571359f4d127d8f89550917ec6561a2fddfd69cd91402b47482/regex-2026.2.28-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:a25c7701e4f7a70021db9aaf4a4a0a67033c6318752146e03d1b94d32006217e", size = 755565, upload-time = 
"2026-02-28T02:16:11.972Z" }, + { url = "https://files.pythonhosted.org/packages/da/bc/91c22f384d79324121b134c267a86ca90d11f8016aafb1dc5bee05890ee3/regex-2026.2.28-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:9dd450db6458387167e033cfa80887a34c99c81d26da1bf8b0b41bf8c9cac88e", size = 835789, upload-time = "2026-02-28T02:16:14.036Z" }, + { url = "https://files.pythonhosted.org/packages/46/a7/4cc94fd3af01dcfdf5a9ed75c8e15fd80fcd62cc46da7592b1749e9c35db/regex-2026.2.28-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2954379dd20752e82d22accf3ff465311cbb2bac6c1f92c4afd400e1757f7451", size = 780094, upload-time = "2026-02-28T02:16:15.468Z" }, + { url = "https://files.pythonhosted.org/packages/3c/21/e5a38f420af3c77cab4a65f0c3a55ec02ac9babf04479cfd282d356988a6/regex-2026.2.28-cp310-cp310-win32.whl", hash = "sha256:1f8b17be5c27a684ea6759983c13506bd77bfc7c0347dff41b18ce5ddd2ee09a", size = 266025, upload-time = "2026-02-28T02:16:16.828Z" }, + { url = "https://files.pythonhosted.org/packages/4d/0a/205c4c1466a36e04d90afcd01d8908bac327673050c7fe316b2416d99d3d/regex-2026.2.28-cp310-cp310-win_amd64.whl", hash = "sha256:dd8847c4978bc3c7e6c826fb745f5570e518b8459ac2892151ce6627c7bc00d5", size = 277965, upload-time = "2026-02-28T02:16:18.752Z" }, + { url = "https://files.pythonhosted.org/packages/c3/4d/29b58172f954b6ec2c5ed28529a65e9026ab96b4b7016bcd3858f1c31d3c/regex-2026.2.28-cp310-cp310-win_arm64.whl", hash = "sha256:73cdcdbba8028167ea81490c7f45280113e41db2c7afb65a276f4711fa3bcbff", size = 270336, upload-time = "2026-02-28T02:16:20.735Z" }, + { url = "https://files.pythonhosted.org/packages/04/db/8cbfd0ba3f302f2d09dd0019a9fcab74b63fee77a76c937d0e33161fb8c1/regex-2026.2.28-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e621fb7c8dc147419b28e1702f58a0177ff8308a76fa295c71f3e7827849f5d9", size = 488462, upload-time = "2026-02-28T02:16:22.616Z" }, + { url = 
"https://files.pythonhosted.org/packages/5d/10/ccc22c52802223f2368731964ddd117799e1390ffc39dbb31634a83022ee/regex-2026.2.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0d5bef2031cbf38757a0b0bc4298bb4824b6332d28edc16b39247228fbdbad97", size = 290774, upload-time = "2026-02-28T02:16:23.993Z" }, + { url = "https://files.pythonhosted.org/packages/62/b9/6796b3bf3101e64117201aaa3a5a030ec677ecf34b3cd6141b5d5c6c67d5/regex-2026.2.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bcb399ed84eabf4282587ba151f2732ad8168e66f1d3f85b1d038868fe547703", size = 288724, upload-time = "2026-02-28T02:16:25.403Z" }, + { url = "https://files.pythonhosted.org/packages/9c/02/291c0ae3f3a10cea941d0f5366da1843d8d1fa8a25b0671e20a0e454bb38/regex-2026.2.28-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7c1b34dfa72f826f535b20712afa9bb3ba580020e834f3c69866c5bddbf10098", size = 791924, upload-time = "2026-02-28T02:16:26.863Z" }, + { url = "https://files.pythonhosted.org/packages/0f/57/f0235cc520d9672742196c5c15098f8f703f2758d48d5a7465a56333e496/regex-2026.2.28-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:851fa70df44325e1e4cdb79c5e676e91a78147b1b543db2aec8734d2add30ec2", size = 860095, upload-time = "2026-02-28T02:16:28.772Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7c/393c94cbedda79a0f5f2435ebd01644aba0b338d327eb24b4aa5b8d6c07f/regex-2026.2.28-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:516604edd17b1c2c3e579cf4e9b25a53bf8fa6e7cedddf1127804d3e0140ca64", size = 906583, upload-time = "2026-02-28T02:16:30.977Z" }, + { url = "https://files.pythonhosted.org/packages/2c/73/a72820f47ca5abf2b5d911d0407ba5178fc52cf9780191ed3a54f5f419a2/regex-2026.2.28-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e7ce83654d1ab701cb619285a18a8e5a889c1216d746ddc710c914ca5fd71022", size = 
800234, upload-time = "2026-02-28T02:16:32.55Z" }, + { url = "https://files.pythonhosted.org/packages/34/b3/6e6a4b7b31fa998c4cf159a12cbeaf356386fbd1a8be743b1e80a3da51e4/regex-2026.2.28-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f2791948f7c70bb9335a9102df45e93d428f4b8128020d85920223925d73b9e1", size = 772803, upload-time = "2026-02-28T02:16:34.029Z" }, + { url = "https://files.pythonhosted.org/packages/10/e7/5da0280c765d5a92af5e1cd324b3fe8464303189cbaa449de9a71910e273/regex-2026.2.28-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:03a83cc26aa2acda6b8b9dfe748cf9e84cbd390c424a1de34fdcef58961a297a", size = 781117, upload-time = "2026-02-28T02:16:36.253Z" }, + { url = "https://files.pythonhosted.org/packages/76/39/0b8d7efb256ae34e1b8157acc1afd8758048a1cf0196e1aec2e71fd99f4b/regex-2026.2.28-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ec6f5674c5dc836994f50f1186dd1fafde4be0666aae201ae2fcc3d29d8adf27", size = 854224, upload-time = "2026-02-28T02:16:38.119Z" }, + { url = "https://files.pythonhosted.org/packages/21/ff/a96d483ebe8fe6d1c67907729202313895d8de8495569ec319c6f29d0438/regex-2026.2.28-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:50c2fc924749543e0eacc93ada6aeeb3ea5f6715825624baa0dccaec771668ae", size = 761898, upload-time = "2026-02-28T02:16:40.333Z" }, + { url = "https://files.pythonhosted.org/packages/89/bd/d4f2e75cb4a54b484e796017e37c0d09d8a0a837de43d17e238adf163f4e/regex-2026.2.28-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:ba55c50f408fb5c346a3a02d2ce0ebc839784e24f7c9684fde328ff063c3cdea", size = 844832, upload-time = "2026-02-28T02:16:41.875Z" }, + { url = "https://files.pythonhosted.org/packages/8a/a7/428a135cf5e15e4e11d1e696eb2bf968362f8ea8a5f237122e96bc2ae950/regex-2026.2.28-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:edb1b1b3a5576c56f08ac46f108c40333f222ebfd5cf63afdfa3aab0791ebe5b", size = 788347, upload-time = "2026-02-28T02:16:43.472Z" }, + { url = 
"https://files.pythonhosted.org/packages/a9/59/68691428851cf9c9c3707217ab1d9b47cfeec9d153a49919e6c368b9e926/regex-2026.2.28-cp311-cp311-win32.whl", hash = "sha256:948c12ef30ecedb128903c2c2678b339746eb7c689c5c21957c4a23950c96d15", size = 266033, upload-time = "2026-02-28T02:16:45.094Z" }, + { url = "https://files.pythonhosted.org/packages/42/8b/1483de1c57024e89296cbcceb9cccb3f625d416ddb46e570be185c9b05a9/regex-2026.2.28-cp311-cp311-win_amd64.whl", hash = "sha256:fd63453f10d29097cc3dc62d070746523973fb5aa1c66d25f8558bebd47fed61", size = 277978, upload-time = "2026-02-28T02:16:46.75Z" }, + { url = "https://files.pythonhosted.org/packages/a4/36/abec45dc6e7252e3dbc797120496e43bb5730a7abf0d9cb69340696a2f2d/regex-2026.2.28-cp311-cp311-win_arm64.whl", hash = "sha256:00f2b8d9615aa165fdff0a13f1a92049bfad555ee91e20d246a51aa0b556c60a", size = 270340, upload-time = "2026-02-28T02:16:48.626Z" }, + { url = "https://files.pythonhosted.org/packages/07/42/9061b03cf0fc4b5fa2c3984cbbaed54324377e440a5c5a29d29a72518d62/regex-2026.2.28-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fcf26c3c6d0da98fada8ae4ef0aa1c3405a431c0a77eb17306d38a89b02adcd7", size = 489574, upload-time = "2026-02-28T02:16:50.455Z" }, + { url = "https://files.pythonhosted.org/packages/77/83/0c8a5623a233015595e3da499c5a1c13720ac63c107897a6037bb97af248/regex-2026.2.28-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02473c954af35dd2defeb07e44182f5705b30ea3f351a7cbffa9177beb14da5d", size = 291426, upload-time = "2026-02-28T02:16:52.52Z" }, + { url = "https://files.pythonhosted.org/packages/9e/06/3ef1ac6910dc3295ebd71b1f9bfa737e82cfead211a18b319d45f85ddd09/regex-2026.2.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9b65d33a17101569f86d9c5966a8b1d7fbf8afdda5a8aa219301b0a80f58cf7d", size = 289200, upload-time = "2026-02-28T02:16:54.08Z" }, + { url = 
"https://files.pythonhosted.org/packages/dd/c9/8cc8d850b35ab5650ff6756a1cb85286e2000b66c97520b29c1587455344/regex-2026.2.28-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e71dcecaa113eebcc96622c17692672c2d104b1d71ddf7adeda90da7ddeb26fc", size = 796765, upload-time = "2026-02-28T02:16:55.905Z" }, + { url = "https://files.pythonhosted.org/packages/e9/5d/57702597627fc23278ebf36fbb497ac91c0ce7fec89ac6c81e420ca3e38c/regex-2026.2.28-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:481df4623fa4969c8b11f3433ed7d5e3dc9cec0f008356c3212b3933fb77e3d8", size = 863093, upload-time = "2026-02-28T02:16:58.094Z" }, + { url = "https://files.pythonhosted.org/packages/02/6d/f3ecad537ca2811b4d26b54ca848cf70e04fcfc138667c146a9f3157779c/regex-2026.2.28-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:64e7c6ad614573e0640f271e811a408d79a9e1fe62a46adb602f598df42a818d", size = 909455, upload-time = "2026-02-28T02:17:00.918Z" }, + { url = "https://files.pythonhosted.org/packages/9e/40/bb226f203caa22c1043c1ca79b36340156eca0f6a6742b46c3bb222a3a57/regex-2026.2.28-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6b08a06976ff4fb0d83077022fde3eca06c55432bb997d8c0495b9a4e9872f4", size = 802037, upload-time = "2026-02-28T02:17:02.842Z" }, + { url = "https://files.pythonhosted.org/packages/44/7c/c6d91d8911ac6803b45ca968e8e500c46934e58c0903cbc6d760ee817a0a/regex-2026.2.28-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:864cdd1a2ef5716b0ab468af40139e62ede1b3a53386b375ec0786bb6783fc05", size = 775113, upload-time = "2026-02-28T02:17:04.506Z" }, + { url = "https://files.pythonhosted.org/packages/dc/8d/4a9368d168d47abd4158580b8c848709667b1cd293ff0c0c277279543bd0/regex-2026.2.28-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:511f7419f7afab475fd4d639d4aedfc54205bcb0800066753ef68a59f0f330b5", size = 784194, upload-time = "2026-02-28T02:17:06.888Z" }, + { url = "https://files.pythonhosted.org/packages/cc/bf/2c72ab5d8b7be462cb1651b5cc333da1d0068740342f350fcca3bca31947/regex-2026.2.28-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:b42f7466e32bf15a961cf09f35fa6323cc72e64d3d2c990b10de1274a5da0a59", size = 856846, upload-time = "2026-02-28T02:17:09.11Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f4/6b65c979bb6d09f51bb2d2a7bc85de73c01ec73335d7ddd202dcb8cd1c8f/regex-2026.2.28-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:8710d61737b0c0ce6836b1da7109f20d495e49b3809f30e27e9560be67a257bf", size = 763516, upload-time = "2026-02-28T02:17:11.004Z" }, + { url = "https://files.pythonhosted.org/packages/8e/32/29ea5e27400ee86d2cc2b4e80aa059df04eaf78b4f0c18576ae077aeff68/regex-2026.2.28-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4390c365fd2d45278f45afd4673cb90f7285f5701607e3ad4274df08e36140ae", size = 849278, upload-time = "2026-02-28T02:17:12.693Z" }, + { url = "https://files.pythonhosted.org/packages/1d/91/3233d03b5f865111cd517e1c95ee8b43e8b428d61fa73764a80c9bb6f537/regex-2026.2.28-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cb3b1db8ff6c7b8bf838ab05583ea15230cb2f678e569ab0e3a24d1e8320940b", size = 790068, upload-time = "2026-02-28T02:17:14.9Z" }, + { url = "https://files.pythonhosted.org/packages/76/92/abc706c1fb03b4580a09645b206a3fc032f5a9f457bc1a8038ac555658ab/regex-2026.2.28-cp312-cp312-win32.whl", hash = "sha256:f8ed9a5d4612df9d4de15878f0bc6aa7a268afbe5af21a3fdd97fa19516e978c", size = 266416, upload-time = "2026-02-28T02:17:17.15Z" }, + { url = "https://files.pythonhosted.org/packages/fa/06/2a6f7dff190e5fa9df9fb4acf2fdf17a1aa0f7f54596cba8de608db56b3a/regex-2026.2.28-cp312-cp312-win_amd64.whl", hash = "sha256:01d65fd24206c8e1e97e2e31b286c59009636c022eb5d003f52760b0f42155d4", size = 277297, upload-time = "2026-02-28T02:17:18.723Z" }, + { url = 
"https://files.pythonhosted.org/packages/b7/f0/58a2484851fadf284458fdbd728f580d55c1abac059ae9f048c63b92f427/regex-2026.2.28-cp312-cp312-win_arm64.whl", hash = "sha256:c0b5ccbb8ffb433939d248707d4a8b31993cb76ab1a0187ca886bf50e96df952", size = 270408, upload-time = "2026-02-28T02:17:20.328Z" }, +] + [[package]] name = "requests" version = "2.32.3" @@ -3476,6 +3778,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438, upload-time = "2024-06-07T18:52:13.582Z" }, ] +[[package]] +name = "tzdata" +version = "2025.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/a7/c202b344c5ca7daf398f3b8a477eeb205cf3b6f32e7ec3a6bac0629ca975/tzdata-2025.3.tar.gz", hash = "sha256:de39c2ca5dc7b0344f2eba86f49d614019d29f060fc4ebc8a417896a620b56a7", size = 196772, upload-time = "2025-12-13T17:45:35.667Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl", hash = "sha256:06a47e5700f3081aab02b2e513160914ff0694bce9947d6b76ebd6bf57cfc5d1", size = 348521, upload-time = "2025-12-13T17:45:33.889Z" }, +] + +[[package]] +name = "tzlocal" +version = "5.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzdata", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8b/2e/c14812d3d4d9cd1773c6be938f89e5735a1f11a9f184ac3639b93cef35d5/tzlocal-5.3.1.tar.gz", hash = "sha256:cceffc7edecefea1f595541dbd6e990cb1ea3d19bf01b2809f362a03dd7921fd", size = 30761, upload-time = "2025-03-05T21:17:41.549Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c2/14/e2a54fabd4f08cd7af1c07030603c3356b74da07f7cc056e600436edfa17/tzlocal-5.3.1-py3-none-any.whl", hash = "sha256:eb1a66c3ef5847adf7a834f1be0800581b683b5608e74f86ecbcef8ab91bb85d", size = 18026, upload-time = "2025-03-05T21:17:39.857Z" }, +] + [[package]] name = "uri-template" version = "1.3.0" @@ -3499,7 +3822,8 @@ name = "utils" version = "0.1.0" source = { editable = "../../libs/utils" } dependencies = [ - { name = "astropy" }, + { name = "astropy", version = "6.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "astropy", version = "7.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "h5py" }, { name = "ml4gw" }, { name = "numpy" }, From 9e3f1ca0ba8cbd96c56dd6cee3753569e5e6ce7c Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 17 Mar 2026 20:56:58 -0700 Subject: [PATCH 6/6] Address comments --- projects/infer/infer/data.py | 179 ++++++++++++++++++----------------- 1 file changed, 91 insertions(+), 88 deletions(-) diff --git a/projects/infer/infer/data.py b/projects/infer/infer/data.py index d1b0a93aa..56d052f18 100644 --- a/projects/infer/infer/data.py +++ b/projects/infer/infer/data.py @@ -75,38 +75,21 @@ def _initialize_sequence_state(self): self._done[seq_id] = False self._sequences[seq_id] = np.zeros(size) - def _finalize_sequence(self, has_foreground: bool = True): - if not self.done: - return None - - # if both the background and foreground - # sequences have completed, return them both, - # slicing off the dummy data from the last batch - background = self._sequences[self.id][self.slice] - foreground = None - if has_foreground: - foreground = self._sequences[self.id + 1][self.slice] - return background, foreground - - def _record_response(self, y, request_id, sequence_id): - # insert the response at the appropriate - # spot in the corresponding output array - start = request_id * 
self.batch_size - stop = (request_id + 1) * self.batch_size - self._sequences[sequence_id][start:stop] = y[:, 0] - - # indicate that the first response for - # this sequence has returned, and possibly - # that the last one has returned as well - self._started[sequence_id] = True - if request_id == len(self) - 1: - self._done[sequence_id] = True - @abstractmethod def _setup(self, **kwargs): """Subclasses must set sample_rate, size, t0, duration, and shifts.""" pass + @property + @abstractmethod + def inference_filenames(self): + pass + + @property + @abstractmethod + def has_foreground(self): + pass + @property def started(self): return all(self._started.values()) @@ -147,6 +130,60 @@ def __len__(self): # corresponding to the excess return math.ceil((self.size - max(self.shifts)) / self.step_size) + def _finalize_sequence(self, has_foreground: bool = True): + if not self.done: + return None + + # if both the background and foreground + # sequences have completed, return them both, + # slicing off the dummy data from the last batch + background = self._sequences[self.id][self.slice] + foreground = None + if has_foreground: + foreground = self._sequences[self.id + 1][self.slice] + return background, foreground + + def _record_response(self, y, request_id, sequence_id): + # insert the response at the appropriate + # spot in the corresponding output array + start = request_id * self.batch_size + stop = (request_id + 1) * self.batch_size + self._sequences[sequence_id][start:stop] = y[:, 0] + + # indicate that the first response for + # this sequence has returned, and possibly + # that the last one has returned as well + self._started[sequence_id] = True + if request_id == len(self) - 1: + self._done[sequence_id] = True + + def __call__(self, y, request_id, sequence_id): + self._record_response(y, request_id, sequence_id) + return self._finalize_sequence(has_foreground=self.has_foreground) + + def _get_data_indices(self, batch_idx: int, shift: int = 0): + # if this is the 
last batch, we may need to pad it + # to make it a full batch + last = batch_idx == len(self) - 1 + # grab the current batch of updates from the file + # and stack it into a 2D array + start = batch_idx * self.step_size + shift + + # for all but last batch just + # increase by step size + end = start + self.step_size + + # if this is the last batch + # and we need to pad it + # just step by the remainder + if last and self.remainder: + end = start + self.remainder + + return start, end, last + + def _pad_last_batch(self, data: np.ndarray): + return np.pad(data, ((0, 0), (0, self.num_pad)), "constant") + class Hdf5Sequence(BaseSequence): def __init__( @@ -208,7 +245,6 @@ def _setup( shifts: list[float], ): self.background_fname = background_fname - self.inference_filenames = [background_fname] self.ifos = ifos if len(ifos) != len(shifts): @@ -253,36 +289,26 @@ def _setup( self.injection_set = injection_set self.shifts = np.array([int(i * self.sample_rate) for i in shifts]) + @property + def inference_filenames(self): + return [self.background_fname] + + @property + def has_foreground(self): + return self.injection_set is not None + def __iter__(self): with h5py.File(self.background_fname, "r") as f: for i in range(len(self)): - # if this is the last batch, we may need to pad it - # to make it a full batch - last = i == len(self) - 1 - # grab the current batch of updates from the file - # and stack it into a 2D array x = [] for ifo, shift in zip(self.ifos, self.shifts, strict=True): - start = shift + i * self.step_size - - # for all but last batch just - # increase by step size - end = start + self.step_size - - # if this is the last batch - # and we need to pad it - # just step by the remainder - if last and self.remainder: - end = start + self.remainder - + start, end, last = self._get_data_indices(i, shift) data = f[ifo][start:end] - # if this is the last batch - # possibly pad it to make it a full batch - if last: - data = np.pad(data, (0, self.num_pad), 
"constant") - x.append(data) + x = np.stack(x).astype(np.float32) + x = self._pad_last_batch(x) if last else x + # if there are any injections for this shift, # inject waveforms into a copy of the background x_inj = None @@ -297,12 +323,6 @@ def __iter__(self): with self.limiter: yield x, x_inj - def __call__(self, y, request_id, sequence_id): - self._record_response(y, request_id, sequence_id) - return self._finalize_sequence( - has_foreground=self.injection_set is not None - ) - def recover(self, foreground: EventSet) -> RecoveredInjectionSet: return RecoveredInjectionSet.recover(foreground, self.injection_set) @@ -370,11 +390,11 @@ def _setup( if not injection_files: raise ValueError("Must provide at least one injection file") - injection_files = sorted(injection_files) - - self.inference_filenames = [fname.name for fname in injection_files] + self.injection_files = sorted(injection_files) - matches = [FNAME_RE.search(fname.name) for fname in injection_files] + matches = [ + FNAME_RE.search(fname.name) for fname in self.injection_files + ] if not all(matches): raise ValueError( "All injection files must match expected name pattern" @@ -386,7 +406,7 @@ def _setup( self.t0 = min(starts) self.duration = sum(durations) self.size = int(self.duration * self.sample_rate) - self.timeseries = np.zeros((len(ifos), self.size)) + self.timeseries = np.zeros((len(ifos), self.size), dtype=np.float32) # Load and resample data from each file for file, start, duration in zip( @@ -400,42 +420,25 @@ def _setup( [injected[ch].value for ch in self.channels] ) + @property + def inference_filenames(self): + return [fname.name for fname in self.injection_files] + + @property + def has_foreground(self): + return True + def __iter__(self): for i in range(len(self)): - # if this is the last batch, we may need to pad it - # to make it a full batch - last = i == len(self) - 1 - # grab the current batch of updates from the file - # and stack it into a 2D array - start = i * self.step_size - - 
# for all but last batch just - # increase by step size - end = start + self.step_size - - # if this is the last batch - # and we need to pad it - # just step by the remainder - if last and self.remainder: - end = start + self.remainder + start, end, last = self._get_data_indices(i) x_inj = self.timeseries[:, start:end] - - if last: - x_inj = np.pad(x_inj, ((0, 0), (0, self.num_pad)), "constant") - - # TODO: Do we need this type conversion? - x_inj = x_inj.astype(np.float32) + x_inj = self._pad_last_batch(x_inj) if last else x_inj # yield the same data twice, as we normally # expect to pass background and foreground with self.limiter: yield x_inj, x_inj - def __call__(self, y, request_id, sequence_id): - self._record_response(y, request_id, sequence_id) - return self._finalize_sequence(has_foreground=True) - - # TODO: format the foreground into whatever R&P expects def recover(self, foreground: EventSet) -> EventSet: return foreground