diff --git a/src/romitask/task.py b/src/romitask/task.py index 890df48..6bc272b 100644 --- a/src/romitask/task.py +++ b/src/romitask/task.py @@ -66,6 +66,7 @@ ... return processed_data """ +import concurrent.futures import glob import json import os.path @@ -845,6 +846,85 @@ def run(self): return +class ParallelFileTask(RomiTask): + """Abstract task to parallelly apply a function to each ``File`` of a ``Fileset``. + + Attributes + ---------- + upstream_task : luigi.TaskParameter + The upstream task. + scan_id : luigi.Parameter, optional + The scan id to use to get or create the ``FilesetTarget``. + If unspecified (default), the current active scan will be used. + query : luigi.DictParameter, optional + A filtering dictionary to apply on input ```Fileset`` metadata. + Key(s) and value(s) must be found in metadata to select the ``File``. + By default, no filtering is performed, all inputs are used. + n_workers : luigi.IntParameter, optional + Number of worker threads to use for parallel processing. + Defaults to ``-1``, which uses the default ``ThreadPoolExecutor`` behavior. + parallel : luigi.BoolParameter, optional + Flag to enable/disable parallel processing. + Defaults to ``True``. + + Notes + ----- + Input `File`s metadata are copied to the target/output `File`s metadata. + This task runs the processing in parallel using ``ThreadPoolExecutor``. + + """ + query = luigi.DictParameter(default={}) + n_workers = luigi.IntParameter(default=-1) + parallel = luigi.BoolParameter(default=True) + + def f(self, f, outfs): + """Function applied to every file in the fileset must return a file object. + + Parameters + ---------- + f: plantdb.commons.fsdb.FSDB.File + Input file. + outfs: plantdb.commons.fsdb.FSDB.Fileset + Output fileset. + + Returns + ------- + plantdb.commons.fsdb.FSDB.File + This file must be created in `outfs`. + """ + raise NotImplementedError + + def run(self, input_fileset, output_fileset): + """Run the task on every `File`s from a `Fileset` that fulfill the ``query`` in parallel.""" + in_files = input_fileset.get_files(query=self.query) + logger.debug(f"Got {len(in_files)} input files:") + logger.debug(f"{', '.join([f.id for f in in_files])}") + logger.debug(f"Got a filtering query: '{self.query}'.") + + # Helper function for parallel processing + def _process_file(fi): + outfi = self.f(fi, output_fileset) + if outfi is not None: + m = fi.get_metadata() + outm = outfi.get_metadata() + outfi.set_metadata({**m, **outm}) + return outfi + + self.n_workers = None if self.n_workers == -1 else self.n_workers + self.parallel = self.n_workers != 1 + if not self.parallel: + # Sequential processing when parallel is disabled + for fi in tqdm(in_files, unit="file"): + _process_file(fi) + else: + # Parallel processing + with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_workers) as executor: + list(tqdm(executor.map(_process_file, in_files), + total=len(in_files), unit="file")) + + return + + @RomiTask.event_handler(luigi.Event.FAILURE) def mourn_failure(task, exception): """In the case of failure of a task, remove the corresponding fileset from the database.