Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/romitask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
... return processed_data
"""

import concurrent.futures
import glob
import json
import os.path
Expand Down Expand Up @@ -845,6 +846,85 @@ def run(self):
return


class ParallelFileTask(RomiTask):
    """Abstract task applying a function to each ``File`` of a ``Fileset`` in parallel.

    Attributes
    ----------
    upstream_task : luigi.TaskParameter
        The upstream task.
    scan_id : luigi.Parameter, optional
        The scan id to use to get or create the ``FilesetTarget``.
        If unspecified (default), the current active scan will be used.
    query : luigi.DictParameter, optional
        A filtering dictionary to apply on input ``Fileset`` metadata.
        Key(s) and value(s) must be found in metadata to select the ``File``.
        By default, no filtering is performed, all inputs are used.
    n_workers : luigi.IntParameter, optional
        Number of worker threads to use for parallel processing.
        Defaults to ``-1``, which uses the default ``ThreadPoolExecutor`` behavior.
        A value of ``1`` forces sequential processing.
    parallel : luigi.BoolParameter, optional
        Flag to enable/disable parallel processing.
        Defaults to ``True``. When ``False``, files are processed sequentially
        whatever the value of ``n_workers``.

    Notes
    -----
    Input `File`s metadata are copied to the target/output `File`s metadata.
    This task runs the processing in parallel using ``ThreadPoolExecutor``.

    """
    query = luigi.DictParameter(default={})
    n_workers = luigi.IntParameter(default=-1)
    parallel = luigi.BoolParameter(default=True)

    def f(self, f, outfs):
        """Function applied to every file in the fileset must return a file object.

        Parameters
        ----------
        f: plantdb.commons.fsdb.FSDB.File
            Input file.
        outfs: plantdb.commons.fsdb.FSDB.Fileset
            Output fileset.

        Returns
        -------
        plantdb.commons.fsdb.FSDB.File
            This file must be created in `outfs`.
        """
        raise NotImplementedError

    def run(self, input_fileset, output_fileset):
        """Run the task on every `File`s from a `Fileset` that fulfill the ``query`` in parallel.

        Parameters
        ----------
        input_fileset
            Fileset whose files, filtered by ``query``, are processed.
        output_fileset
            Fileset handed to ``f`` to receive the created files.

        Notes
        -----
        Processing is sequential when ``parallel`` is ``False`` or when
        ``n_workers`` equals ``1``; otherwise a ``ThreadPoolExecutor`` is used.

        NOTE(review): luigi invokes ``run()`` without arguments; presumably a
        subclass or caller supplies both filesets -- confirm.
        """
        in_files = input_fileset.get_files(query=self.query)
        logger.debug(f"Got {len(in_files)} input files:")
        logger.debug(f"{', '.join([f.id for f in in_files])}")
        logger.debug(f"Got a filtering query: '{self.query}'.")

        # Helper function for parallel processing
        def _process_file(fi):
            outfi = self.f(fi, output_fileset)
            if outfi is not None:
                # Output metadata takes precedence over the copied input metadata.
                m = fi.get_metadata()
                outm = outfi.get_metadata()
                outfi.set_metadata({**m, **outm})
            return outfi

        # Work on local values: the task parameters must not be rewritten, and
        # the user-provided `parallel` flag must be honoured, not overridden.
        max_workers = None if self.n_workers == -1 else self.n_workers
        use_threads = self.parallel and max_workers != 1

        if not use_threads:
            # Sequential processing when parallel is disabled
            for fi in tqdm(in_files, unit="file"):
                _process_file(fi)
        else:
            # Parallel processing; consuming the map re-raises any worker exception.
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                list(tqdm(executor.map(_process_file, in_files),
                          total=len(in_files), unit="file"))

        return


@RomiTask.event_handler(luigi.Event.FAILURE)
def mourn_failure(task, exception):
"""In the case of failure of a task, remove the corresponding fileset from the database.
Expand Down
Loading