diff --git a/iron/common/compilation/base.py b/iron/common/compilation/base.py index d6d17a64..27a814be 100644 --- a/iron/common/compilation/base.py +++ b/iron/common/compilation/base.py @@ -56,18 +56,39 @@ @dataclass class DesignGenerator: - """Lazy callable that imports source_path and calls fn_name(*args, **kwargs), returning MLIR as a string.""" + """Lazy callable that imports *source_path* and calls *fn_name*, returning MLIR. + + On first call the loaded module is registered in ``sys.modules`` under + its fully-qualified package name (e.g. ``iron.operators.axpy.design``) so + that subsequent code can import it through the standard Python import + mechanism. + """ source_path: Path fn_name: str args: tuple = () kwargs: dict[str, Any] = field(default_factory=dict) + @property + def module_name(self) -> str: + """Python module path derived from *source_path*. + + Walks up from *source_path* until the ``iron`` package root is found, + then builds a dotted path relative to it. + """ + parts = list(self.source_path.resolve().parts) + try: + idx = parts.index("iron") + except ValueError: + return self.source_path.stem + rel_parts = parts[idx:-1] + [self.source_path.stem] + return ".".join(rel_parts) + def __call__(self) -> str: - spec = importlib.util.spec_from_file_location( - self.source_path.name, self.source_path - ) + module_name = self.module_name + spec = importlib.util.spec_from_file_location(module_name, self.source_path) module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module spec.loader.exec_module(module) return str(getattr(module, self.fn_name)(*self.args, **self.kwargs)) diff --git a/iron/common/compilation/fusion.py b/iron/common/compilation/fusion.py index fdb6c8a8..fce22f80 100644 --- a/iron/common/compilation/fusion.py +++ b/iron/common/compilation/fusion.py @@ -8,7 +8,7 @@ from __future__ import annotations import numpy as np -import importlib.util +import importlib from functools import partial from pathlib import Path from aie import ir @@ -74,18 +74,17 @@ def extract_runtime_sequence_arg_types(dev_op: Any) -> list[Any]: def get_child_mlir_module(mlir_artifact: PythonGeneratedMLIRArtifact) -> Any: """Extract MLIR module from a PythonGeneratedMLIRArtifact. - Uses the artifact's DesignGenerator to dynamically import the design - module and call the callback, returning the raw (non-stringified) MLIR - module object for further inspection by the fusion pass. + Imports the design module via the standard Python import mechanism + (relying on it having been registered in ``sys.modules`` by + ``DesignGenerator.__call__``) and calls the callback, returning the + raw (non-stringified) MLIR module object. """ if not isinstance(mlir_artifact, PythonGeneratedMLIRArtifact): raise TypeError( f"Expected PythonGeneratedMLIRArtifact, got {type(mlir_artifact).__name__}" ) gen = mlir_artifact.generator - spec = importlib.util.spec_from_file_location(gen.source_path.name, gen.source_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + module = importlib.import_module(gen.module_name) callback_function = getattr(module, gen.fn_name) return callback_function(*gen.args, **gen.kwargs)