From 9f15b1f719e9798e17c8f1b1d08f7c8fcf3cee96 Mon Sep 17 00:00:00 2001 From: masteryi-0018 <1536474741@qq.com> Date: Tue, 16 Jun 2026 19:42:02 +0800 Subject: [PATCH] Register design modules in sys.modules to fix path wrangling (#42) DesignGenerator.__call__ now registers the loaded module under its fully-qualified package name (e.g. iron.operators.axpy.design) so that subsequent code can import it through the standard Python import mechanism instead of duplicating spec_from_file_location calls. get_child_mlir_module in fusion.py now uses importlib.import_module to retrieve the already-registered module. Closes #42 Co-Authored-By: Claude Opus 4.8 --- iron/common/compilation/base.py | 29 +++++++++++++++++++++++++---- iron/common/compilation/fusion.py | 13 ++++++------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/iron/common/compilation/base.py b/iron/common/compilation/base.py index d6d17a64..27a814be 100644 --- a/iron/common/compilation/base.py +++ b/iron/common/compilation/base.py @@ -56,18 +56,39 @@ @dataclass class DesignGenerator: - """Lazy callable that imports source_path and calls fn_name(*args, **kwargs), returning MLIR as a string.""" + """Lazy callable that imports *source_path* and calls *fn_name*, returning MLIR. + + On first call the loaded module is registered in ``sys.modules`` under + its fully-qualified package name (e.g. ``iron.operators.axpy.design``) so + that subsequent code can import it through the standard Python import + mechanism. + """ source_path: Path fn_name: str args: tuple = () kwargs: dict[str, Any] = field(default_factory=dict) + @property + def module_name(self) -> str: + """Python module path derived from *source_path*. + + Walks up from *source_path* until the ``iron`` package root is found, + then builds a dotted path relative to it. + """ + parts = list(self.source_path.resolve().parts) + try: + idx = parts.index("iron") + except ValueError: + return self.source_path.stem + rel_parts = parts[idx:-1] + [self.source_path.stem] + return ".".join(rel_parts) + def __call__(self) -> str: - spec = importlib.util.spec_from_file_location( - self.source_path.name, self.source_path - ) + module_name = self.module_name + spec = importlib.util.spec_from_file_location(module_name, self.source_path) module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module spec.loader.exec_module(module) return str(getattr(module, self.fn_name)(*self.args, **self.kwargs)) diff --git a/iron/common/compilation/fusion.py b/iron/common/compilation/fusion.py index fdb6c8a8..fce22f80 100644 --- a/iron/common/compilation/fusion.py +++ b/iron/common/compilation/fusion.py @@ -8,7 +8,7 @@ from __future__ import annotations import numpy as np -import importlib.util +import importlib from functools import partial from pathlib import Path from aie import ir @@ -74,18 +74,17 @@ def extract_runtime_sequence_arg_types(dev_op: Any) -> list[Any]: def get_child_mlir_module(mlir_artifact: PythonGeneratedMLIRArtifact) -> Any: """Extract MLIR module from a PythonGeneratedMLIRArtifact. - Uses the artifact's DesignGenerator to dynamically import the design - module and call the callback, returning the raw (non-stringified) MLIR - module object for further inspection by the fusion pass. + Imports the design module via the standard Python import mechanism + (relying on it having been registered in ``sys.modules`` by + ``DesignGenerator.__call__``) and calls the callback, returning the + raw (non-stringified) MLIR module object. """ if not isinstance(mlir_artifact, PythonGeneratedMLIRArtifact): raise TypeError( f"Expected PythonGeneratedMLIRArtifact, got {type(mlir_artifact).__name__}" ) gen = mlir_artifact.generator - spec = importlib.util.spec_from_file_location(gen.source_path.name, gen.source_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + module = importlib.import_module(gen.module_name) callback_function = getattr(module, gen.fn_name) return callback_function(*gen.args, **gen.kwargs)