Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aframe/tasks/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .infer import Infer
from .infer import Infer, InferRnP
240 changes: 167 additions & 73 deletions aframe/tasks/infer/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import warnings
from abc import abstractmethod
from pathlib import Path

import h5py
Expand All @@ -13,7 +14,7 @@
from aframe.base import AframeSingularityTask
from aframe.config import paths
from aframe.parameters import PathParameter
from aframe.tasks import ExportLocal, TestingWaveforms
from aframe.tasks import ExportLocal, TestingWaveforms, Train
from aframe.tasks.data.condor.workflows import StaticMemoryWorkflow


Expand Down Expand Up @@ -115,77 +116,38 @@ def timeseries_output(self):
def metadata_output(self):
return self.tmp_dir / "metadata.json"

@property
def background_fnames(self):
return self.workflow_input()["data"].collection.targets.values()
def workflow_requires(self):
reqs = {
"model_repository": ExportLocal.req(
self, train_task=self.train_task
)
}
reqs.update(self._workflow_requires())
return reqs

@property
def injection_set_fname(self):
return self.workflow_input()["waveforms"][0].path
@abstractmethod
def _workflow_requires(self):
pass

def workflow_requires(self):
reqs = {}
reqs["model_repository"] = ExportLocal.req(self)
testing_waveforms = TestingWaveforms.req(self)
fetch = testing_waveforms.requires().workflow_requires()[
"test_segments"
]
reqs["data"] = fetch
reqs["waveforms"] = testing_waveforms
@abstractmethod
def _workflow_condition(self) -> bool:
pass

return reqs
@abstractmethod
def _create_branch_map(self):
pass

def get_num_shifts(self):
# calculate the number of shifts required
# to accumulate the requested background,
# given the duration of the background segments
segments = data_utils.segments_from_paths(self.background_fnames)
num_shifts = data_utils.get_num_shifts_from_Tb(
segments,
self.Tb,
max(self.shifts),
self.psd_length,
)
return num_shifts
@abstractmethod
def create_sequence(self):
pass

@law.dynamic_workflow_condition(cache_met_condition=True)
def workflow_condition(self) -> bool:
return self.workflow_input()["data"].collection.exists()
return self._workflow_condition()

@workflow_condition.create_branch_map
def create_branch_map(self):
# create the individual fname shift
# combinations that represent individual
# condor inference jobs to be submitted
branch_map = {}
num_shifts = self.get_num_shifts()
counter = 0
for fname in self.background_fnames:
fname = Path(fname.path)
start, duration = map(float, fname.stem.split("-")[-2:])
stop = start + duration

if self.zero_lag:
# check if segment is long enough to be analyzed
if data_utils.is_analyzeable_segment(
start, stop, [0] * len(self.shifts), self.psd_length
):
_shifts = [0 for s in self.shifts]
branch_map[counter] = (fname, _shifts)
counter += 1

if num_shifts > 0:
for i in range(num_shifts):
_shifts = [s * (i + 1) for s in self.shifts]
# check if segment is long enough to be analyzed
if data_utils.is_analyzeable_segment(
start, stop, _shifts, self.psd_length
):
# unique identifier for mapping to branch map
branch_map[counter] = (fname, _shifts)
counter += 1

return branch_map
return self._create_branch_map()

def get_ip_address(self) -> str:
raise NotImplementedError
Expand All @@ -203,22 +165,12 @@ def output(self):
def run(self):
from hermes.aeriel.client import InferenceClient

from infer.data import Sequence
from infer.main import infer
from infer.postprocess import Postprocessor

ip = os.getenv("AFRAME_TRITON_IP")
self.tmp_dir.mkdir(exist_ok=True, parents=True)
fname, shifts = self.branch_data
sequence = Sequence(
ifos=self.ifos,
batch_size=self.batch_size,
inference_sampling_rate=self.inference_sampling_rate,
rate=self.rate_per_client,
shifts=shifts,
background_fname=fname,
injection_set_fname=self.injection_set_fname,
)
sequence, shifts = self.create_sequence()

postprocessor = Postprocessor(
integration_window_length=self.integration_window_length,
Expand Down Expand Up @@ -267,3 +219,145 @@ def run(self):
}
with open(self.metadata_output, "w") as f:
json.dump(metadata, f)


@inherits(InferParameters)
class Hdf5InferBase(InferBase):
    """
    Inference workflow that builds an ``Hdf5Sequence`` from one
    background segment file and one injection set file per branch.

    Each branch of the workflow is a ``(background file, shifts)``
    pair, where ``shifts`` holds one time shift per entry of
    ``self.shifts``.
    """

    @property
    def background_fnames(self):
        """Targets of the ``data`` collection in the workflow input."""
        data = self.workflow_input()["data"]
        return data.collection.targets.values()

    @property
    def injection_set_fname(self):
        """Path of the first target of the ``waveforms`` input."""
        waveforms = self.workflow_input()["waveforms"]
        return waveforms[0].path

    def _workflow_requires(self):
        """
        Require the testing waveforms together with the
        ``test_segments`` requirement of the task they depend on.
        """
        waveforms = TestingWaveforms.req(self)
        upstream = waveforms.requires().workflow_requires()
        return {
            "data": upstream["test_segments"],
            "waveforms": waveforms,
        }

    def get_num_shifts(self):
        """
        Number of shifts needed to accumulate the requested
        background ``Tb``, given the duration of the background
        segments.
        """
        segments = data_utils.segments_from_paths(self.background_fnames)
        return data_utils.get_num_shifts_from_Tb(
            segments,
            self.Tb,
            max(self.shifts),
            self.psd_length,
        )

    def _workflow_condition(self) -> bool:
        """True once the ``data`` collection of the input exists."""
        data = self.workflow_input()["data"]
        return data.collection.exists()

    def _create_branch_map(self):
        """
        Map consecutive integer branch ids to ``(fname, shifts)``
        pairs, each of which is one condor inference job.

        For every background file the zero-lag shifts (only when
        ``self.zero_lag`` is set) and then each multiple of
        ``self.shifts`` are considered in turn; a pair is kept only
        if the segment is long enough to be analyzed with it.
        """
        num_shifts = self.get_num_shifts()
        if num_shifts > 0:
            multiples = range(1, num_shifts + 1)
        else:
            multiples = ()

        branch_map = {}
        for target in self.background_fnames:
            path = Path(target.path)
            # the last two dash-separated fields of the file stem
            # are parsed as the segment start and duration
            start, duration = map(float, path.stem.split("-")[-2:])
            stop = start + duration

            candidates = []
            if self.zero_lag:
                candidates.append([0] * len(self.shifts))
            candidates.extend(
                [shift * k for shift in self.shifts] for k in multiples
            )

            for shifts in candidates:
                if data_utils.is_analyzeable_segment(
                    start, stop, shifts, self.psd_length
                ):
                    # ids are handed out sequentially from 0, so the
                    # current size is always the next free id
                    branch_map[len(branch_map)] = (path, shifts)

        return branch_map

    def create_sequence(self):
        """
        Build the ``Hdf5Sequence`` for this branch.

        Returns the sequence and the shifts stored in the branch data.
        """
        from infer.data import Hdf5Sequence

        background, shifts = self.branch_data
        return (
            Hdf5Sequence(
                ifos=self.ifos,
                batch_size=self.batch_size,
                inference_sampling_rate=self.inference_sampling_rate,
                rate=self.rate_per_client,
                shifts=shifts,
                background_fname=str(background),
                injection_set_fname=str(self.injection_set_fname),
            ),
            shifts,
        )


# InferBase variant whose inputs are the ``*.gwf`` files found
# (recursively) under ``injection_file_dir``.
@inherits(InferParameters)
class RnPInferBase(InferBase):
# task class exposed as ``train_task``
# NOTE(review): presumably consumed by the base class when it
# requests the exported model - confirm against InferBase
train_task = Train

# directory that is searched for injection files
injection_file_dir = luigi.PathParameter()
channel = luigi.Parameter(
description="Name of channel within injection files"
)
# forwarded to the sequence as ``sample_rate``
# NOTE(review): units are not stated here (presumably Hz) - confirm
sample_rate = luigi.FloatParameter(default=2048.0)
# size of the groups the sorted injection files are split into
files_per_job = luigi.IntParameter(
default=1,
description="Number of injection files to process per condor job.",
)

def __init__(self, *args, **kwargs):
    """
    Initialize the task, then collect every ``*.gwf`` file found
    recursively under ``injection_file_dir``.
    """
    super().__init__(*args, **kwargs)
    search_root = Path(self.injection_file_dir)
    self.injection_files = [*search_root.rglob("*.gwf")]

def _workflow_requires(self):
return {}

def _workflow_condition(self) -> bool:
return bool(self.injection_files) and all(
f.exists() for f in self.injection_files
)

# Build the branch map: each branch id maps to a group of injection
# files plus the shifts to apply to them.
def _create_branch_map(self):
# nothing to analyze, so there are no branches
if not self.injection_files:
return {}

# one zero shift per interferometer, shared by every branch
shifts = [0.0 for _ in self.ifos]
# sort so that the grouping does not depend on discovery order
files = sorted(self.injection_files)
branch_map = {}
# step through the sorted files in strides of ``files_per_job``;
# ``i`` is the branch id and ``start`` the offset of its group
for i, start in enumerate(range(0, len(files), self.files_per_job)):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So compared to the Hdf5 sequence, here we analyze multiple files at once?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, all the frame files are 512 seconds long. There are a lot of them, so cropping the PSD length from the start of each would remove a lot of signals if we analyzed each frame individually.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got it makes sense - how are science quality segments handled?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the frames are organized into chunks, which cover different time periods. All the frames within a chunk are continuous, and chunks are continuous with each other. I haven't checked what specifically is done for non-science-ready times, though I think they might just zero them out. But that shouldn't break anything, I think.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use the science segments we query to group the frames accordingly?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think there's a bit more logic that could be implemented for grouping frames. We can definitely query the science segments and cut out any frames that are fully not ready, though we should still analyze partially ready frames. The chunks also repeat with different CBC populations, so we could avoid concatenating frames that go over a border.

chunk = files[start : start + self.files_per_job]
branch_map[i] = (chunk, shifts)
return branch_map

def create_sequence(self):
    """
    Build the ``RnPSequence`` for this branch.

    Returns the sequence and the shifts stored in the branch data.
    """
    from infer.data import RnPSequence

    files, shifts = self.branch_data
    return (
        RnPSequence(
            injection_files=files,
            channel=self.channel,
            ifos=self.ifos,
            sample_rate=self.sample_rate,
            inference_sampling_rate=self.inference_sampling_rate,
            batch_size=self.batch_size,
            rate=self.rate_per_client,
        ),
        shifts,
    )
Loading
Loading